diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index 6c1bb6eaa5..eec35fd62f 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -8,6 +8,9 @@ on: release: types: [published] +permissions: + contents: read + concurrency: group: docker-${{ github.ref }} cancel-in-progress: true @@ -17,22 +20,29 @@ jobs: # Only run on the upstream repository, not on forks if: github.repository == 'NousResearch/hermes-agent' runs-on: ubuntu-latest - timeout-minutes: 30 + timeout-minutes: 60 steps: - name: Checkout code uses: actions/checkout@v4 with: submodules: recursive + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - - name: Build image + # Build amd64 only so we can `load` the image for smoke testing. + # `load: true` cannot export a multi-arch manifest to the local daemon. + # The multi-arch build follows on push to main / release. + - name: Build image (amd64, smoke test) uses: docker/build-push-action@v6 with: context: . file: Dockerfile load: true + platforms: linux/amd64 tags: nousresearch/hermes-agent:test cache-from: type=gha cache-to: type=gha,mode=max @@ -51,26 +61,28 @@ jobs: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Push image (main branch) + - name: Push multi-arch image (main branch) if: github.event_name == 'push' && github.ref == 'refs/heads/main' uses: docker/build-push-action@v6 with: context: . file: Dockerfile push: true + platforms: linux/amd64,linux/arm64 tags: | nousresearch/hermes-agent:latest nousresearch/hermes-agent:${{ github.sha }} cache-from: type=gha cache-to: type=gha,mode=max - - name: Push image (release) + - name: Push multi-arch image (release) if: github.event_name == 'release' uses: docker/build-push-action@v6 with: context: . file: Dockerfile push: true + platforms: linux/amd64,linux/arm64 tags: | nousresearch/hermes-agent:latest nousresearch/hermes-agent:${{ github.event.release.tag_name }} diff --git a/.github/workflows/docs-site-checks.yml b/.github/workflows/docs-site-checks.yml index 14cdb8f6a6..ea05d28046 100644 --- a/.github/workflows/docs-site-checks.yml +++ b/.github/workflows/docs-site-checks.yml @@ -27,8 +27,8 @@ jobs: with: python-version: '3.11' - - name: Install Python dependencies - run: python -m pip install ascii-guard pyyaml + - name: Install ascii-guard + run: python -m pip install ascii-guard==2.3.0 pyyaml==6.0.3 - name: Extract skill metadata for dashboard run: python3 website/scripts/extract-skills.py diff --git a/.github/workflows/nix.yml b/.github/workflows/nix.yml index 004f8236a2..dba33bfffc 100644 --- a/.github/workflows/nix.yml +++ b/.github/workflows/nix.yml @@ -27,8 +27,8 @@ jobs: timeout-minutes: 30 steps: - uses: actions/checkout@v4 - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main + - uses: DeterminateSystems/nix-installer-action@ef8a148080ab6020fd15196c2084a2eea5ff2d25 # v22 + - uses: DeterminateSystems/magic-nix-cache-action@565684385bcd71bad329742eefe8d12f2e765b39 # v13 - name: Check flake if: runner.os == 'Linux' run: nix flake check --print-build-logs diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index f743a64eeb..27c67c10a3 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -629,11 +629,19 @@ def _nous_base_url() -> str: def _read_codex_access_token() -> Optional[str]: - """Read a valid, non-expired Codex OAuth access token from Hermes auth store.""" + """Read a valid, non-expired Codex OAuth access token from Hermes auth store. + + If a credential pool exists but currently has no selectable runtime entry + (for example all pool slots are marked exhausted), fall back to the + profile's auth.json token instead of hard-failing. This keeps explicit + fallback-to-Codex working when the pool state is stale but the stored OAuth + token is still valid. + """ pool_present, entry = _select_pool_entry("openai-codex") if pool_present: token = _pool_runtime_api_key(entry) - return token or None + if token: + return token try: from hermes_cli.auth import _read_codex_tokens @@ -894,9 +902,13 @@ def _try_codex() -> Tuple[Optional[Any], Optional[str]]: pool_present, entry = _select_pool_entry("openai-codex") if pool_present: codex_token = _pool_runtime_api_key(entry) - if not codex_token: - return None, None - base_url = _pool_runtime_base_url(entry, _CODEX_AUX_BASE_URL) or _CODEX_AUX_BASE_URL + if codex_token: + base_url = _pool_runtime_base_url(entry, _CODEX_AUX_BASE_URL) or _CODEX_AUX_BASE_URL + else: + codex_token = _read_codex_access_token() + if not codex_token: + return None, None + base_url = _CODEX_AUX_BASE_URL else: codex_token = _read_codex_access_token() if not codex_token: diff --git a/agent/context_compressor.py b/agent/context_compressor.py index 0d971e4b56..c61cf2c5a7 100644 --- a/agent/context_compressor.py +++ b/agent/context_compressor.py @@ -154,12 +154,15 @@ class ContextCompressor: def _prune_old_tool_results( self, messages: List[Dict[str, Any]], protect_tail_count: int, + protect_tail_tokens: int | None = None, ) -> tuple[List[Dict[str, Any]], int]: """Replace old tool result contents with a short placeholder. - Walks backward from the end, protecting the most recent - ``protect_tail_count`` messages. Older tool results get their - content replaced with a placeholder string. + Walks backward from the end, protecting the most recent messages that + fall within ``protect_tail_tokens`` (when provided) OR the last + ``protect_tail_count`` messages (backward-compatible default). + When both are given, the token budget takes priority and the message + count acts as a hard minimum floor. Returns (pruned_messages, pruned_count). """ @@ -168,7 +171,29 @@ class ContextCompressor: result = [m.copy() for m in messages] pruned = 0 - prune_boundary = len(result) - protect_tail_count + + # Determine the prune boundary + if protect_tail_tokens is not None and protect_tail_tokens > 0: + # Token-budget approach: walk backward accumulating tokens + accumulated = 0 + boundary = len(result) + min_protect = min(protect_tail_count, len(result) - 1) + for i in range(len(result) - 1, -1, -1): + msg = result[i] + content_len = len(msg.get("content") or "") + msg_tokens = content_len // _CHARS_PER_TOKEN + 10 + for tc in msg.get("tool_calls") or []: + if isinstance(tc, dict): + args = tc.get("function", {}).get("arguments", "") + msg_tokens += len(args) // _CHARS_PER_TOKEN + if accumulated + msg_tokens > protect_tail_tokens and (len(result) - i) >= min_protect: + boundary = i + break + accumulated += msg_tokens + boundary = i + prune_boundary = max(boundary, len(result) - min_protect) + else: + prune_boundary = len(result) - protect_tail_count for i in range(prune_boundary): msg = result[i] @@ -199,30 +224,39 @@ class ContextCompressor: budget = int(content_tokens * _SUMMARY_RATIO) return max(_MIN_SUMMARY_TOKENS, min(budget, self.max_summary_tokens)) + # Truncation limits for the summarizer input. These bound how much of + # each message the summary model sees — the budget is the *summary* + # model's context window, not the main model's. + _CONTENT_MAX = 6000 # total chars per message body + _CONTENT_HEAD = 4000 # chars kept from the start + _CONTENT_TAIL = 1500 # chars kept from the end + _TOOL_ARGS_MAX = 1500 # tool call argument chars + _TOOL_ARGS_HEAD = 1200 # kept from the start of tool args + def _serialize_for_summary(self, turns: List[Dict[str, Any]]) -> str: """Serialize conversation turns into labeled text for the summarizer. - Includes tool call arguments and result content (up to 3000 chars - per message) so the summarizer can preserve specific details like - file paths, commands, and outputs. + Includes tool call arguments and result content (up to + ``_CONTENT_MAX`` chars per message) so the summarizer can preserve + specific details like file paths, commands, and outputs. """ parts = [] for msg in turns: role = msg.get("role", "unknown") content = msg.get("content") or "" - # Tool results: keep more content than before (3000 chars) + # Tool results: keep enough content for the summarizer if role == "tool": tool_id = msg.get("tool_call_id", "") - if len(content) > 3000: - content = content[:2000] + "\n...[truncated]...\n" + content[-800:] + if len(content) > self._CONTENT_MAX: + content = content[:self._CONTENT_HEAD] + "\n...[truncated]...\n" + content[-self._CONTENT_TAIL:] parts.append(f"[TOOL RESULT {tool_id}]: {content}") continue # Assistant messages: include tool call names AND arguments if role == "assistant": - if len(content) > 3000: - content = content[:2000] + "\n...[truncated]...\n" + content[-800:] + if len(content) > self._CONTENT_MAX: + content = content[:self._CONTENT_HEAD] + "\n...[truncated]...\n" + content[-self._CONTENT_TAIL:] tool_calls = msg.get("tool_calls", []) if tool_calls: tc_parts = [] @@ -232,8 +266,8 @@ class ContextCompressor: name = fn.get("name", "?") args = fn.get("arguments", "") # Truncate long arguments but keep enough for context - if len(args) > 500: - args = args[:400] + "..." + if len(args) > self._TOOL_ARGS_MAX: + args = args[:self._TOOL_ARGS_HEAD] + "..." tc_parts.append(f" {name}({args})") else: fn = getattr(tc, "function", None) @@ -244,8 +278,8 @@ class ContextCompressor: continue # User and other roles - if len(content) > 3000: - content = content[:2000] + "\n...[truncated]...\n" + content[-800:] + if len(content) > self._CONTENT_MAX: + content = content[:self._CONTENT_HEAD] + "\n...[truncated]...\n" + content[-self._CONTENT_TAIL:] parts.append(f"[{role.upper()}]: {content}") return "\n\n".join(parts) @@ -310,6 +344,9 @@ Update the summary using this exact structure. PRESERVE all existing information ## Critical Context [Any specific values, error messages, configuration details, or data that would be lost without explicit preservation] +## Tools & Patterns +[Which tools were used, how they were used effectively, and any tool-specific discoveries. Accumulate across compactions.] + Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions. Write only the summary body. Do not include any preamble or prefix.""" @@ -348,6 +385,9 @@ Use this exact structure: ## Critical Context [Any specific values, error messages, configuration details, or data that would be lost without explicit preservation] +## Tools & Patterns +[Which tools were used, how they were used effectively, and any tool-specific discoveries (e.g., preferred flags, working invocations, successful command patterns)] + Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions. The goal is to prevent the next assistant from repeating work or losing important details. Write only the summary body. Do not include any preamble or prefix.""" @@ -518,13 +558,20 @@ Write only the summary body. Do not include any preamble or prefix.""" derived from ``summary_target_ratio * context_length``, so it scales automatically with the model's context window. - Never cuts inside a tool_call/result group. Falls back to the old - ``protect_last_n`` if the budget would protect fewer messages. + Token budget is the primary criterion. A hard minimum of 3 messages + is always protected, but the budget is allowed to exceed by up to + 1.5x to avoid cutting inside an oversized message (tool output, file + read, etc.). If even the minimum 3 messages exceed 1.5x the budget + the cut is placed right after the head so compression still runs. + + Never cuts inside a tool_call/result group. """ if token_budget is None: token_budget = self.tail_token_budget n = len(messages) - min_tail = self.protect_last_n + # Hard minimum: always keep at least 3 messages in the tail + min_tail = min(3, n - head_end - 1) if n - head_end > 1 else 0 + soft_ceiling = int(token_budget * 1.5) accumulated = 0 cut_idx = n # start from beyond the end @@ -537,21 +584,21 @@ Write only the summary body. Do not include any preamble or prefix.""" if isinstance(tc, dict): args = tc.get("function", {}).get("arguments", "") msg_tokens += len(args) // _CHARS_PER_TOKEN - if accumulated + msg_tokens > token_budget and (n - i) >= min_tail: + # Stop once we exceed the soft ceiling (unless we haven't hit min_tail yet) + if accumulated + msg_tokens > soft_ceiling and (n - i) >= min_tail: break accumulated += msg_tokens cut_idx = i - # Ensure we protect at least protect_last_n messages + # Ensure we protect at least min_tail messages fallback_cut = n - min_tail if cut_idx > fallback_cut: cut_idx = fallback_cut # If the token budget would protect everything (small conversations), - # fall back to the fixed protect_last_n approach so compression can - # still remove middle turns. + # force a cut after the head so compression can still remove middle turns. if cut_idx <= head_end: - cut_idx = fallback_cut + cut_idx = max(fallback_cut, head_end + 1) # Align to avoid splitting tool groups cut_idx = self._align_boundary_backward(messages, cut_idx) @@ -576,12 +623,13 @@ Write only the summary body. Do not include any preamble or prefix.""" up so the API never receives mismatched IDs. """ n_messages = len(messages) - if n_messages <= self.protect_first_n + self.protect_last_n + 1: + # Only need head + 3 tail messages minimum (token budget decides the real tail size) + _min_for_compress = self.protect_first_n + 3 + 1 + if n_messages <= _min_for_compress: if not self.quiet_mode: logger.warning( "Cannot compress: only %d messages (need > %d)", - n_messages, - self.protect_first_n + self.protect_last_n + 1, + n_messages, _min_for_compress, ) return messages @@ -589,7 +637,8 @@ Write only the summary body. Do not include any preamble or prefix.""" # Phase 1: Prune old tool results (cheap, no LLM call) messages, pruned_count = self._prune_old_tool_results( - messages, protect_tail_count=self.protect_last_n * 3, + messages, protect_tail_count=self.protect_last_n, + protect_tail_tokens=self.tail_token_budget, ) if pruned_count and not self.quiet_mode: logger.info("Pre-compression: pruned %d old tool result(s)", pruned_count) diff --git a/agent/credential_pool.py b/agent/credential_pool.py index a47901c847..dd2c9abc5e 100644 --- a/agent/credential_pool.py +++ b/agent/credential_pool.py @@ -64,10 +64,10 @@ SUPPORTED_POOL_STRATEGIES = { } # Cooldown before retrying an exhausted credential. -# 429 (rate-limited) cools down faster since quotas reset frequently. -# 402 (billing/quota) and other codes use a longer default. +# 429 (rate-limited) and 402 (billing/quota) both cool down after 1 hour. +# Provider-supplied reset_at timestamps override these defaults. EXHAUSTED_TTL_429_SECONDS = 60 * 60 # 1 hour -EXHAUSTED_TTL_DEFAULT_SECONDS = 24 * 60 * 60 # 24 hours +EXHAUSTED_TTL_DEFAULT_SECONDS = 60 * 60 # 1 hour # Pool key prefix for custom OpenAI-compatible endpoints. # Custom endpoints all share provider='custom' but are keyed by their diff --git a/agent/error_classifier.py b/agent/error_classifier.py new file mode 100644 index 0000000000..b227932ad7 --- /dev/null +++ b/agent/error_classifier.py @@ -0,0 +1,789 @@ +"""API error classification for smart failover and recovery. + +Provides a structured taxonomy of API errors and a priority-ordered +classification pipeline that determines the correct recovery action +(retry, rotate credential, fallback to another provider, compress +context, or abort). + +Replaces scattered inline string-matching with a centralized classifier +that the main retry loop in run_agent.py consults for every API failure. +""" + +from __future__ import annotations + +import enum +import logging +import re +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + +logger = logging.getLogger(__name__) + + +# ── Error taxonomy ────────────────────────────────────────────────────── + +class FailoverReason(enum.Enum): + """Why an API call failed — determines recovery strategy.""" + + # Authentication / authorization + auth = "auth" # Transient auth (401/403) — refresh/rotate + auth_permanent = "auth_permanent" # Auth failed after refresh — abort + + # Billing / quota + billing = "billing" # 402 or confirmed credit exhaustion — rotate immediately + rate_limit = "rate_limit" # 429 or quota-based throttling — backoff then rotate + + # Server-side + overloaded = "overloaded" # 503/529 — provider overloaded, backoff + server_error = "server_error" # 500/502 — internal server error, retry + + # Transport + timeout = "timeout" # Connection/read timeout — rebuild client + retry + + # Context / payload + context_overflow = "context_overflow" # Context too large — compress, not failover + payload_too_large = "payload_too_large" # 413 — compress payload + + # Model + model_not_found = "model_not_found" # 404 or invalid model — fallback to different model + + # Request format + format_error = "format_error" # 400 bad request — abort or strip + retry + + # Provider-specific + thinking_signature = "thinking_signature" # Anthropic thinking block sig invalid + long_context_tier = "long_context_tier" # Anthropic "extra usage" tier gate + + # Catch-all + unknown = "unknown" # Unclassifiable — retry with backoff + + +# ── Classification result ─────────────────────────────────────────────── + +@dataclass +class ClassifiedError: + """Structured classification of an API error with recovery hints.""" + + reason: FailoverReason + status_code: Optional[int] = None + provider: Optional[str] = None + model: Optional[str] = None + message: str = "" + error_context: Dict[str, Any] = field(default_factory=dict) + + # Recovery action hints — the retry loop checks these instead of + # re-classifying the error itself. + retryable: bool = True + should_compress: bool = False + should_rotate_credential: bool = False + should_fallback: bool = False + + @property + def is_auth(self) -> bool: + return self.reason in (FailoverReason.auth, FailoverReason.auth_permanent) + + @property + def is_transient(self) -> bool: + """Error is expected to resolve on retry (with or without backoff).""" + return self.reason in ( + FailoverReason.rate_limit, + FailoverReason.overloaded, + FailoverReason.server_error, + FailoverReason.timeout, + FailoverReason.unknown, + ) + + +# ── Provider-specific patterns ────────────────────────────────────────── + +# Patterns that indicate billing exhaustion (not transient rate limit) +_BILLING_PATTERNS = [ + "insufficient credits", + "insufficient_quota", + "credit balance", + "credits have been exhausted", + "top up your credits", + "payment required", + "billing hard limit", + "exceeded your current quota", + "account is deactivated", + "plan does not include", +] + +# Patterns that indicate rate limiting (transient, will resolve) +_RATE_LIMIT_PATTERNS = [ + "rate limit", + "rate_limit", + "too many requests", + "throttled", + "requests per minute", + "tokens per minute", + "requests per day", + "try again in", + "please retry after", + "resource_exhausted", +] + +# Usage-limit patterns that need disambiguation (could be billing OR rate_limit) +_USAGE_LIMIT_PATTERNS = [ + "usage limit", + "quota", + "limit exceeded", + "key limit exceeded", +] + +# Patterns confirming usage limit is transient (not billing) +_USAGE_LIMIT_TRANSIENT_SIGNALS = [ + "try again", + "retry", + "resets at", + "reset in", + "wait", + "requests remaining", + "periodic", + "window", +] + +# Payload-too-large patterns detected from message text (no status_code attr). +# Proxies and some backends embed the HTTP status in the error message. +_PAYLOAD_TOO_LARGE_PATTERNS = [ + "request entity too large", + "payload too large", + "error code: 413", +] + +# Context overflow patterns +_CONTEXT_OVERFLOW_PATTERNS = [ + "context length", + "context size", + "maximum context", + "token limit", + "too many tokens", + "reduce the length", + "exceeds the limit", + "context window", + "prompt is too long", + "prompt exceeds max length", + "max_tokens", + "maximum number of tokens", + # Chinese error messages (some providers return these) + "超过最大长度", + "上下文长度", +] + +# Model not found patterns +_MODEL_NOT_FOUND_PATTERNS = [ + "is not a valid model", + "invalid model", + "model not found", + "model_not_found", + "does not exist", + "no such model", + "unknown model", + "unsupported model", +] + +# Auth patterns (non-status-code signals) +_AUTH_PATTERNS = [ + "invalid api key", + "invalid_api_key", + "authentication", + "unauthorized", + "forbidden", + "invalid token", + "token expired", + "token revoked", + "access denied", +] + +# Anthropic thinking block signature patterns +_THINKING_SIG_PATTERNS = [ + "signature", # Combined with "thinking" check +] + +# Transport error type names +_TRANSPORT_ERROR_TYPES = frozenset({ + "ReadTimeout", "ConnectTimeout", "PoolTimeout", + "ConnectError", "RemoteProtocolError", + "ConnectionError", "ConnectionResetError", + "ConnectionAbortedError", "BrokenPipeError", + "TimeoutError", "ReadError", + "ServerDisconnectedError", + # OpenAI SDK errors (not subclasses of Python builtins) + "APIConnectionError", + "APITimeoutError", +}) + +# Server disconnect patterns (no status code, but transport-level) +_SERVER_DISCONNECT_PATTERNS = [ + "server disconnected", + "peer closed connection", + "connection reset by peer", + "connection was closed", + "network connection lost", + "unexpected eof", + "incomplete chunked read", +] + + +# ── Classification pipeline ───────────────────────────────────────────── + +def classify_api_error( + error: Exception, + *, + provider: str = "", + model: str = "", + approx_tokens: int = 0, + context_length: int = 200000, + num_messages: int = 0, +) -> ClassifiedError: + """Classify an API error into a structured recovery recommendation. + + Priority-ordered pipeline: + 1. Special-case provider-specific patterns (thinking sigs, tier gates) + 2. HTTP status code + message-aware refinement + 3. Error code classification (from body) + 4. Message pattern matching (billing vs rate_limit vs context vs auth) + 5. Transport error heuristics + 6. Server disconnect + large session → context overflow + 7. Fallback: unknown (retryable with backoff) + + Args: + error: The exception from the API call. + provider: Current provider name (e.g. "openrouter", "anthropic"). + model: Current model slug. + approx_tokens: Approximate token count of the current context. + context_length: Maximum context length for the current model. + + Returns: + ClassifiedError with reason and recovery action hints. + """ + status_code = _extract_status_code(error) + error_type = type(error).__name__ + body = _extract_error_body(error) + error_code = _extract_error_code(body) + + # Build a comprehensive error message string for pattern matching. + # str(error) alone may not include the body message (e.g. OpenAI SDK's + # APIStatusError.__str__ returns the first arg, not the body). Append + # the body message so patterns like "try again" in 402 disambiguation + # are detected even when only present in the structured body. + # + # Also extract metadata.raw — OpenRouter wraps upstream provider errors + # inside {"error": {"message": "Provider returned error", "metadata": + # {"raw": ""}}} and the real error message (e.g. + # "context length exceeded") is only in the inner JSON. + _raw_msg = str(error).lower() + _body_msg = "" + _metadata_msg = "" + if isinstance(body, dict): + _err_obj = body.get("error", {}) + if isinstance(_err_obj, dict): + _body_msg = (_err_obj.get("message") or "").lower() + # Parse metadata.raw for wrapped provider errors + _metadata = _err_obj.get("metadata", {}) + if isinstance(_metadata, dict): + _raw_json = _metadata.get("raw") or "" + if isinstance(_raw_json, str) and _raw_json.strip(): + try: + import json + _inner = json.loads(_raw_json) + if isinstance(_inner, dict): + _inner_err = _inner.get("error", {}) + if isinstance(_inner_err, dict): + _metadata_msg = (_inner_err.get("message") or "").lower() + except (json.JSONDecodeError, TypeError): + pass + if not _body_msg: + _body_msg = (body.get("message") or "").lower() + # Combine all message sources for pattern matching + parts = [_raw_msg] + if _body_msg and _body_msg not in _raw_msg: + parts.append(_body_msg) + if _metadata_msg and _metadata_msg not in _raw_msg and _metadata_msg not in _body_msg: + parts.append(_metadata_msg) + error_msg = " ".join(parts) + provider_lower = (provider or "").strip().lower() + model_lower = (model or "").strip().lower() + + def _result(reason: FailoverReason, **overrides) -> ClassifiedError: + defaults = { + "reason": reason, + "status_code": status_code, + "provider": provider, + "model": model, + "message": _extract_message(error, body), + } + defaults.update(overrides) + return ClassifiedError(**defaults) + + # ── 1. Provider-specific patterns (highest priority) ──────────── + + # Anthropic thinking block signature invalid (400). + # Don't gate on provider — OpenRouter proxies Anthropic errors, so the + # provider may be "openrouter" even though the error is Anthropic-specific. + # The message pattern ("signature" + "thinking") is unique enough. + if ( + status_code == 400 + and "signature" in error_msg + and "thinking" in error_msg + ): + return _result( + FailoverReason.thinking_signature, + retryable=True, + should_compress=False, + ) + + # Anthropic long-context tier gate (429 "extra usage" + "long context") + if ( + status_code == 429 + and "extra usage" in error_msg + and "long context" in error_msg + ): + return _result( + FailoverReason.long_context_tier, + retryable=True, + should_compress=True, + ) + + # ── 2. HTTP status code classification ────────────────────────── + + if status_code is not None: + classified = _classify_by_status( + status_code, error_msg, error_code, body, + provider=provider_lower, model=model_lower, + approx_tokens=approx_tokens, context_length=context_length, + num_messages=num_messages, + result_fn=_result, + ) + if classified is not None: + return classified + + # ── 3. Error code classification ──────────────────────────────── + + if error_code: + classified = _classify_by_error_code(error_code, error_msg, _result) + if classified is not None: + return classified + + # ── 4. Message pattern matching (no status code) ──────────────── + + classified = _classify_by_message( + error_msg, error_type, + approx_tokens=approx_tokens, + context_length=context_length, + result_fn=_result, + ) + if classified is not None: + return classified + + # ── 5. Server disconnect + large session → context overflow ───── + # Must come BEFORE generic transport error catch — a disconnect on + # a large session is more likely context overflow than a transient + # transport hiccup. Without this ordering, RemoteProtocolError + # always maps to timeout regardless of session size. + + is_disconnect = any(p in error_msg for p in _SERVER_DISCONNECT_PATTERNS) + if is_disconnect and not status_code: + is_large = approx_tokens > context_length * 0.6 or approx_tokens > 120000 or num_messages > 200 + if is_large: + return _result( + FailoverReason.context_overflow, + retryable=True, + should_compress=True, + ) + return _result(FailoverReason.timeout, retryable=True) + + # ── 6. Transport / timeout heuristics ─────────────────────────── + + if error_type in _TRANSPORT_ERROR_TYPES or isinstance(error, (TimeoutError, ConnectionError, OSError)): + return _result(FailoverReason.timeout, retryable=True) + + # ── 7. Fallback: unknown ──────────────────────────────────────── + + return _result(FailoverReason.unknown, retryable=True) + + +# ── Status code classification ────────────────────────────────────────── + +def _classify_by_status( + status_code: int, + error_msg: str, + error_code: str, + body: dict, + *, + provider: str, + model: str, + approx_tokens: int, + context_length: int, + num_messages: int = 0, + result_fn, +) -> Optional[ClassifiedError]: + """Classify based on HTTP status code with message-aware refinement.""" + + if status_code == 401: + # Not retryable on its own — credential pool rotation and + # provider-specific refresh (Codex, Anthropic, Nous) run before + # the retryability check in run_agent.py. If those succeed, the + # loop `continue`s. If they fail, retryable=False ensures we + # hit the client-error abort path (which tries fallback first). + return result_fn( + FailoverReason.auth, + retryable=False, + should_rotate_credential=True, + should_fallback=True, + ) + + if status_code == 403: + # OpenRouter 403 "key limit exceeded" is actually billing + if "key limit exceeded" in error_msg or "spending limit" in error_msg: + return result_fn( + FailoverReason.billing, + retryable=False, + should_rotate_credential=True, + should_fallback=True, + ) + return result_fn( + FailoverReason.auth, + retryable=False, + should_fallback=True, + ) + + if status_code == 402: + return _classify_402(error_msg, result_fn) + + if status_code == 404: + if any(p in error_msg for p in _MODEL_NOT_FOUND_PATTERNS): + return result_fn( + FailoverReason.model_not_found, + retryable=False, + should_fallback=True, + ) + # Generic 404 — could be model or endpoint + return result_fn( + FailoverReason.model_not_found, + retryable=False, + should_fallback=True, + ) + + if status_code == 413: + return result_fn( + FailoverReason.payload_too_large, + retryable=True, + should_compress=True, + ) + + if status_code == 429: + # Already checked long_context_tier above; this is a normal rate limit + return result_fn( + FailoverReason.rate_limit, + retryable=True, + should_rotate_credential=True, + should_fallback=True, + ) + + if status_code == 400: + return _classify_400( + error_msg, error_code, body, + provider=provider, model=model, + approx_tokens=approx_tokens, + context_length=context_length, + num_messages=num_messages, + result_fn=result_fn, + ) + + if status_code in (500, 502): + return result_fn(FailoverReason.server_error, retryable=True) + + if status_code in (503, 529): + return result_fn(FailoverReason.overloaded, retryable=True) + + # Other 4xx — non-retryable + if 400 <= status_code < 500: + return result_fn( + FailoverReason.format_error, + retryable=False, + should_fallback=True, + ) + + # Other 5xx — retryable + if 500 <= status_code < 600: + return result_fn(FailoverReason.server_error, retryable=True) + + return None + + +def _classify_402(error_msg: str, result_fn) -> ClassifiedError: + """Disambiguate 402: billing exhaustion vs transient usage limit. + + The key insight from OpenClaw: some 402s are transient rate limits + disguised as payment errors. "Usage limit, try again in 5 minutes" + is NOT a billing problem — it's a periodic quota that resets. + """ + # Check for transient usage-limit signals first + has_usage_limit = any(p in error_msg for p in _USAGE_LIMIT_PATTERNS) + has_transient_signal = any(p in error_msg for p in _USAGE_LIMIT_TRANSIENT_SIGNALS) + + if has_usage_limit and has_transient_signal: + # Transient quota — treat as rate limit, not billing + return result_fn( + FailoverReason.rate_limit, + retryable=True, + should_rotate_credential=True, + should_fallback=True, + ) + + # Confirmed billing exhaustion + return result_fn( + FailoverReason.billing, + retryable=False, + should_rotate_credential=True, + should_fallback=True, + ) + + +def _classify_400( + error_msg: str, + error_code: str, + body: dict, + *, + provider: str, + model: str, + approx_tokens: int, + context_length: int, + num_messages: int = 0, + result_fn, +) -> ClassifiedError: + """Classify 400 Bad Request — context overflow, format error, or generic.""" + + # Context overflow from 400 + if any(p in error_msg for p in _CONTEXT_OVERFLOW_PATTERNS): + return result_fn( + FailoverReason.context_overflow, + retryable=True, + should_compress=True, + ) + + # Some providers return model-not-found as 400 instead of 404 (e.g. OpenRouter). + if any(p in error_msg for p in _MODEL_NOT_FOUND_PATTERNS): + return result_fn( + FailoverReason.model_not_found, + retryable=False, + should_fallback=True, + ) + + # Some providers return rate limit / billing errors as 400 instead of 429/402. + # Check these patterns before falling through to format_error. + if any(p in error_msg for p in _RATE_LIMIT_PATTERNS): + return result_fn( + FailoverReason.rate_limit, + retryable=True, + should_rotate_credential=True, + should_fallback=True, + ) + if any(p in error_msg for p in _BILLING_PATTERNS): + return result_fn( + FailoverReason.billing, + retryable=False, + should_rotate_credential=True, + should_fallback=True, + ) + + # Generic 400 + large session → probable context overflow + # Anthropic sometimes returns a bare "Error" message when context is too large + err_body_msg = "" + if isinstance(body, dict): + err_obj = body.get("error", {}) + if isinstance(err_obj, dict): + err_body_msg = (err_obj.get("message") or "").strip().lower() + is_generic = len(err_body_msg) < 30 or err_body_msg in ("error", "") + is_large = approx_tokens > context_length * 0.4 or approx_tokens > 80000 or num_messages > 80 + + if is_generic and is_large: + return result_fn( + FailoverReason.context_overflow, + retryable=True, + should_compress=True, + ) + + # Non-retryable format error + return result_fn( + FailoverReason.format_error, + retryable=False, + should_fallback=True, + ) + + +# ── Error code classification ─────────────────────────────────────────── + +def _classify_by_error_code( + error_code: str, error_msg: str, result_fn, +) -> Optional[ClassifiedError]: + """Classify by structured error codes from the response body.""" + code_lower = error_code.lower() + + if code_lower in ("resource_exhausted", "throttled", "rate_limit_exceeded"): + return result_fn( + FailoverReason.rate_limit, + retryable=True, + should_rotate_credential=True, + ) + + if code_lower in ("insufficient_quota", "billing_not_active", "payment_required"): + return result_fn( + FailoverReason.billing, + retryable=False, + should_rotate_credential=True, + should_fallback=True, + ) + + if code_lower in ("model_not_found", "model_not_available", "invalid_model"): + return result_fn( + FailoverReason.model_not_found, + retryable=False, + should_fallback=True, + ) + + if code_lower in ("context_length_exceeded", "max_tokens_exceeded"): + return result_fn( + FailoverReason.context_overflow, + retryable=True, + should_compress=True, + ) + + return None + + +# ── Message pattern classification ────────────────────────────────────── + +def _classify_by_message( + error_msg: str, + error_type: str, + *, + approx_tokens: int, + context_length: int, + result_fn, +) -> Optional[ClassifiedError]: + """Classify based on error message patterns when no status code is available.""" + + # Payload-too-large patterns (from message text when no status_code) + if any(p in error_msg for p in _PAYLOAD_TOO_LARGE_PATTERNS): + return result_fn( + FailoverReason.payload_too_large, + retryable=True, + should_compress=True, + ) + + # Billing patterns + if any(p in error_msg for p in _BILLING_PATTERNS): + return result_fn( + FailoverReason.billing, + retryable=False, + should_rotate_credential=True, + should_fallback=True, + ) + + # Rate limit patterns + if any(p in error_msg for p in _RATE_LIMIT_PATTERNS): + return result_fn( + FailoverReason.rate_limit, + retryable=True, + should_rotate_credential=True, + should_fallback=True, + ) + + # Context overflow patterns + if any(p in error_msg for p in _CONTEXT_OVERFLOW_PATTERNS): + return result_fn( + FailoverReason.context_overflow, + retryable=True, + should_compress=True, + ) + + # Auth patterns + if any(p in error_msg for p in _AUTH_PATTERNS): + return result_fn( + FailoverReason.auth, + retryable=True, + should_rotate_credential=True, + ) + + # Model not found patterns + if any(p in error_msg for p in _MODEL_NOT_FOUND_PATTERNS): + return result_fn( + FailoverReason.model_not_found, + retryable=False, + should_fallback=True, + ) + + return None + + +# ── Helpers ───────────────────────────────────────────────────────────── + +def _extract_status_code(error: Exception) -> Optional[int]: + """Walk the error and its cause chain to find an HTTP status code.""" + current = error + for _ in range(5): # Max depth to prevent infinite loops + code = getattr(current, "status_code", None) + if isinstance(code, int): + return code + # Some SDKs use .status instead of .status_code + code = getattr(current, "status", None) + if isinstance(code, int) and 100 <= code < 600: + return code + # Walk cause chain + cause = getattr(current, "__cause__", None) or getattr(current, "__context__", None) + if cause is None or cause is current: + break + current = cause + return None + + +def _extract_error_body(error: Exception) -> dict: + """Extract the structured error body from an SDK exception.""" + body = getattr(error, "body", None) + if isinstance(body, dict): + return body + # Some errors have .response.json() + response = getattr(error, "response", None) + if response is not None: + try: + json_body = response.json() + if isinstance(json_body, dict): + return json_body + except Exception: + pass + return {} + + +def _extract_error_code(body: dict) -> str: + """Extract an error code string from the response body.""" + if not body: + return "" + error_obj = body.get("error", {}) + if isinstance(error_obj, dict): + code = error_obj.get("code") or error_obj.get("type") or "" + if isinstance(code, str) and code.strip(): + return code.strip() + # Top-level code + code = body.get("code") or body.get("error_code") or "" + if isinstance(code, (str, int)): + return str(code).strip() + return "" + + +def _extract_message(error: Exception, body: dict) -> str: + """Extract the most informative error message.""" + # Try structured body first + if body: + error_obj = body.get("error", {}) + if isinstance(error_obj, dict): + msg = error_obj.get("message", "") + if isinstance(msg, str) and msg.strip(): + return msg.strip()[:500] + msg = body.get("message", "") + if isinstance(msg, str) and msg.strip(): + return msg.strip()[:500] + # Fallback to str(error) + return str(error)[:500] diff --git a/agent/model_metadata.py b/agent/model_metadata.py index 5b1d3376af..9282586fea 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -197,6 +197,7 @@ _URL_TO_PROVIDER: Dict[str, str] = { "api.githubcopilot.com": "copilot", "models.github.ai": "copilot", "api.fireworks.ai": "fireworks", + "opencode.ai": "opencode-go", } diff --git a/agent/prompt_builder.py b/agent/prompt_builder.py index b1b0891f59..8302973aac 100644 --- a/agent/prompt_builder.py +++ b/agent/prompt_builder.py @@ -349,6 +349,13 @@ PLATFORM_HINTS = { "only — no markdown, no formatting. SMS messages are limited to ~1600 " "characters, so be brief and direct." ), + "bluebubbles": ( + "You are chatting via iMessage (BlueBubbles). iMessage does not render " + "markdown formatting — use plain text. Keep responses concise as they " + "appear as text messages. You can send media files natively: include " + "MEDIA:/absolute/path/to/file in your response. Images (.jpg, .png, " + ".heic) appear as photos and other files arrive as attachments." + ), } CONTEXT_FILE_MAX_CHARS = 20_000 diff --git a/agent/rate_limit_tracker.py b/agent/rate_limit_tracker.py new file mode 100644 index 0000000000..c87e096a1d --- /dev/null +++ b/agent/rate_limit_tracker.py @@ -0,0 +1,242 @@ +"""Rate limit tracking for inference API responses. + +Captures x-ratelimit-* headers from provider responses and provides +formatted display for the /usage slash command. Currently supports +the Nous Portal header format (also used by OpenRouter and OpenAI-compatible +APIs that follow the same convention). + +Header schema (12 headers total): + x-ratelimit-limit-requests RPM cap + x-ratelimit-limit-requests-1h RPH cap + x-ratelimit-limit-tokens TPM cap + x-ratelimit-limit-tokens-1h TPH cap + x-ratelimit-remaining-requests requests left in minute window + x-ratelimit-remaining-requests-1h requests left in hour window + x-ratelimit-remaining-tokens tokens left in minute window + x-ratelimit-remaining-tokens-1h tokens left in hour window + x-ratelimit-reset-requests seconds until minute request window resets + x-ratelimit-reset-requests-1h seconds until hour request window resets + x-ratelimit-reset-tokens seconds until minute token window resets + x-ratelimit-reset-tokens-1h seconds until hour token window resets +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from typing import Any, Dict, Mapping, Optional + + +@dataclass +class RateLimitBucket: + """One rate-limit window (e.g. requests per minute).""" + + limit: int = 0 + remaining: int = 0 + reset_seconds: float = 0.0 + captured_at: float = 0.0 # time.time() when this was captured + + @property + def used(self) -> int: + return max(0, self.limit - self.remaining) + + @property + def usage_pct(self) -> float: + if self.limit <= 0: + return 0.0 + return (self.used / self.limit) * 100.0 + + @property + def remaining_seconds_now(self) -> float: + """Estimated seconds remaining until reset, adjusted for elapsed time.""" + elapsed = time.time() - self.captured_at + return max(0.0, self.reset_seconds - elapsed) + + +@dataclass +class RateLimitState: + """Full rate-limit state parsed from response headers.""" + + requests_min: RateLimitBucket = field(default_factory=RateLimitBucket) + requests_hour: RateLimitBucket = field(default_factory=RateLimitBucket) + tokens_min: RateLimitBucket = field(default_factory=RateLimitBucket) + tokens_hour: RateLimitBucket = field(default_factory=RateLimitBucket) + captured_at: float = 0.0 # when the headers were captured + provider: str = "" + + @property + def has_data(self) -> bool: + return self.captured_at > 0 + + @property + def age_seconds(self) -> float: + if not self.has_data: + return float("inf") + return time.time() - self.captured_at + + +def _safe_int(value: Any, default: int = 0) -> int: + try: + return int(float(value)) + except (TypeError, ValueError): + return default + + +def _safe_float(value: Any, default: float = 0.0) -> float: + try: + return float(value) + except (TypeError, ValueError): + return default + + +def parse_rate_limit_headers( + headers: Mapping[str, str], + provider: str = "", +) -> Optional[RateLimitState]: + """Parse x-ratelimit-* headers into a RateLimitState. + + Returns None if no rate limit headers are present. + """ + # Quick check: at least one rate limit header must exist + has_any = any(k.lower().startswith("x-ratelimit-") for k in headers) + if not has_any: + return None + + now = time.time() + + def _bucket(resource: str, suffix: str = "") -> RateLimitBucket: + # e.g. resource="requests", suffix="" -> per-minute + # resource="tokens", suffix="-1h" -> per-hour + tag = f"{resource}{suffix}" + return RateLimitBucket( + limit=_safe_int(headers.get(f"x-ratelimit-limit-{tag}")), + remaining=_safe_int(headers.get(f"x-ratelimit-remaining-{tag}")), + reset_seconds=_safe_float(headers.get(f"x-ratelimit-reset-{tag}")), + captured_at=now, + ) + + return RateLimitState( + requests_min=_bucket("requests"), + requests_hour=_bucket("requests", "-1h"), + tokens_min=_bucket("tokens"), + tokens_hour=_bucket("tokens", "-1h"), + captured_at=now, + provider=provider, + ) + + +# ── Formatting ────────────────────────────────────────────────────────── + + +def _fmt_count(n: int) -> str: + """Human-friendly number: 7999856 -> '8.0M', 33599 -> '33.6K', 799 -> '799'.""" + if n >= 1_000_000: + return f"{n / 1_000_000:.1f}M" + if n >= 10_000: + return f"{n / 1_000:.1f}K" + if n >= 1_000: + return f"{n / 1_000:.1f}K" + return str(n) + + +def _fmt_seconds(seconds: float) -> str: + """Seconds -> human-friendly duration: '58s', '2m 14s', '58m 57s', '1h 2m'.""" + s = max(0, int(seconds)) + if s < 60: + return f"{s}s" + if s < 3600: + m, sec = divmod(s, 60) + return f"{m}m {sec}s" if sec else f"{m}m" + h, remainder = divmod(s, 3600) + m = remainder // 60 + return f"{h}h {m}m" if m else f"{h}h" + + +def _bar(pct: float, width: int = 20) -> str: + """ASCII progress bar: [████████░░░░░░░░░░░░] 40%.""" + filled = int(pct / 100.0 * width) + filled = max(0, min(width, filled)) + empty = width - filled + return f"[{'█' * filled}{'░' * empty}]" + + +def _bucket_line(label: str, bucket: RateLimitBucket, label_width: int = 14) -> str: + """Format one bucket as a single line.""" + if bucket.limit <= 0: + return f" {label:<{label_width}} (no data)" + + pct = bucket.usage_pct + used = _fmt_count(bucket.used) + limit = _fmt_count(bucket.limit) + remaining = _fmt_count(bucket.remaining) + reset = _fmt_seconds(bucket.remaining_seconds_now) + + bar = _bar(pct) + return f" {label:<{label_width}} {bar} {pct:5.1f}% {used}/{limit} used ({remaining} left, resets in {reset})" + + +def format_rate_limit_display(state: RateLimitState) -> str: + """Format rate limit state for terminal/chat display.""" + if not state.has_data: + return "No rate limit data yet — make an API request first." + + age = state.age_seconds + if age < 5: + freshness = "just now" + elif age < 60: + freshness = f"{int(age)}s ago" + else: + freshness = f"{_fmt_seconds(age)} ago" + + provider_label = state.provider.title() if state.provider else "Provider" + + lines = [ + f"{provider_label} Rate Limits (captured {freshness}):", + "", + _bucket_line("Requests/min", state.requests_min), + _bucket_line("Requests/hr", state.requests_hour), + "", + _bucket_line("Tokens/min", state.tokens_min), + _bucket_line("Tokens/hr", state.tokens_hour), + ] + + # Add warnings if any bucket is getting hot + warnings = [] + for label, bucket in [ + ("requests/min", state.requests_min), + ("requests/hr", state.requests_hour), + ("tokens/min", state.tokens_min), + ("tokens/hr", state.tokens_hour), + ]: + if bucket.limit > 0 and bucket.usage_pct >= 80: + reset = _fmt_seconds(bucket.remaining_seconds_now) + warnings.append(f" ⚠ {label} at {bucket.usage_pct:.0f}% — resets in {reset}") + + if warnings: + lines.append("") + lines.extend(warnings) + + return "\n".join(lines) + + +def format_rate_limit_compact(state: RateLimitState) -> str: + """One-line compact summary for status bars / gateway messages.""" + if not state.has_data: + return "No rate limit data." + + rm = state.requests_min + tm = state.tokens_min + rh = state.requests_hour + th = state.tokens_hour + + parts = [] + if rm.limit > 0: + parts.append(f"RPM: {rm.remaining}/{rm.limit}") + if rh.limit > 0: + parts.append(f"RPH: {_fmt_count(rh.remaining)}/{_fmt_count(rh.limit)} (resets {_fmt_seconds(rh.remaining_seconds_now)})") + if tm.limit > 0: + parts.append(f"TPM: {_fmt_count(tm.remaining)}/{_fmt_count(tm.limit)}") + if th.limit > 0: + parts.append(f"TPH: {_fmt_count(th.remaining)}/{_fmt_count(th.limit)} (resets {_fmt_seconds(th.remaining_seconds_now)})") + + return " | ".join(parts) diff --git a/agent/subdirectory_hints.py b/agent/subdirectory_hints.py index 96903e2e28..dcc514b901 100644 --- a/agent/subdirectory_hints.py +++ b/agent/subdirectory_hints.py @@ -159,7 +159,10 @@ class SubdirectoryHintTracker: def _is_valid_subdir(self, path: Path) -> bool: """Check if path is a valid directory to scan for hints.""" - if not path.is_dir(): + try: + if not path.is_dir(): + return False + except OSError: return False if path in self._loaded_dirs: return False @@ -172,7 +175,10 @@ class SubdirectoryHintTracker: found_hints = [] for filename in _HINT_FILENAMES: hint_path = directory / filename - if not hint_path.is_file(): + try: + if not hint_path.is_file(): + continue + except OSError: continue try: content = hint_path.read_text(encoding="utf-8").strip() diff --git a/cli-config.yaml.example b/cli-config.yaml.example index 14d764d7d1..d75284443f 100644 --- a/cli-config.yaml.example +++ b/cli-config.yaml.example @@ -117,7 +117,8 @@ terminal: timeout: 180 docker_mount_cwd_to_workspace: false # SECURITY: off by default. Opt in to mount the launch cwd into Docker /workspace. lifetime_seconds: 300 - # sudo_password: "" # Enable sudo commands (pipes via sudo -S) - SECURITY WARNING: plaintext! + # sudo_password: "hunter2" # Optional: pipe a sudo password via sudo -S. SECURITY WARNING: plaintext. + # sudo_password: "" # Explicit empty password: try empty and never open the interactive sudo prompt. # ----------------------------------------------------------------------------- # OPTION 2: SSH remote execution @@ -208,13 +209,18 @@ terminal: # # SECURITY WARNING: Password stored in plaintext! # -# INTERACTIVE PROMPT: If no sudo_password is set and the CLI is running, +# INTERACTIVE PROMPT: If sudo_password is unset and the CLI is running, # you'll be prompted to enter your password when sudo is needed: # - 45-second timeout (auto-skips if no input) # - Press Enter to skip (command fails gracefully) # - Password is hidden while typing # - Password is cached for the session # +# EMPTY PASSWORDS: Setting sudo_password to an explicit empty string is different +# from leaving it unset. Hermes will try an empty password via `sudo -S` and +# will not open the interactive prompt. This is useful for passwordless sudo, +# Touch ID sudo setups, and environments where prompting is just noise. +# # ALTERNATIVES: # - SSH backend: Configure passwordless sudo on the remote server # - Containers: Run as root inside the container (no sudo needed) @@ -445,6 +451,16 @@ agent: # Higher = more room for complex tasks, but costs more tokens # Recommended: 20-30 for focused tasks, 50-100 for open exploration max_turns: 60 + + # Inactivity timeout for gateway agent runs (seconds, 0 = unlimited). + # The agent can run indefinitely when actively calling tools or receiving + # API responses. Only fires after the agent has been idle for this duration. + # gateway_timeout: 1800 + + # Staged warning: send a warning before escalating to full timeout. + # Fires once per run when inactivity reaches this threshold (seconds). + # Set to 0 to disable the warning. + # gateway_timeout_warning: 900 # Enable verbose logging verbose: false diff --git a/cli.py b/cli.py index f0edf67ee2..fa32ae9119 100644 --- a/cli.py +++ b/cli.py @@ -1546,6 +1546,7 @@ class HermesCLI: self._clarify_deadline = 0 self._sudo_state = None self._sudo_deadline = 0 + self._modal_input_snapshot = None self._approval_state = None self._approval_deadline = 0 self._approval_lock = threading.Lock() @@ -5408,12 +5409,27 @@ class HermesCLI: print(f" ❌ Compression failed: {e}") def _show_usage(self): - """Show cumulative token usage for the current session.""" + """Show rate limits (if available) and session token usage.""" if not self.agent: print("(._.) No active agent -- send a message first.") return agent = self.agent + calls = agent.session_api_calls + + if calls == 0: + print("(._.) No API calls made yet in this session.") + return + + # ── Rate limits (shown first when available) ──────────────── + rl_state = agent.get_rate_limit_state() + if rl_state and rl_state.has_data: + from agent.rate_limit_tracker import format_rate_limit_display + print() + print(format_rate_limit_display(rl_state)) + print() + + # ── Session token usage ───────────────────────────────────── input_tokens = getattr(agent, "session_input_tokens", 0) or 0 output_tokens = getattr(agent, "session_output_tokens", 0) or 0 cache_read_tokens = getattr(agent, "session_cache_read_tokens", 0) or 0 @@ -5421,13 +5437,7 @@ class HermesCLI: prompt = agent.session_prompt_tokens completion = agent.session_completion_tokens total = agent.session_total_tokens - calls = agent.session_api_calls - if calls == 0: - print("(._.) No API calls made yet in this session.") - return - - # Current context window state compressor = agent.context_compressor last_prompt = compressor.last_prompt_tokens ctx_len = compressor.context_length @@ -6205,6 +6215,7 @@ class HermesCLI: timeout = 45 response_queue = queue.Queue() + self._capture_modal_input_snapshot() self._sudo_state = { "response_queue": response_queue, } @@ -6217,6 +6228,7 @@ class HermesCLI: result = response_queue.get(timeout=1) self._sudo_state = None self._sudo_deadline = 0 + self._restore_modal_input_snapshot() self._invalidate() if result: _cprint(f"\n{_DIM} ✓ Password received (cached for session){_RST}") @@ -6231,6 +6243,7 @@ class HermesCLI: self._sudo_state = None self._sudo_deadline = 0 + self._restore_modal_input_snapshot() self._invalidate() _cprint(f"\n{_DIM} ⏱ Timeout — continuing without sudo{_RST}") return "" @@ -6403,6 +6416,33 @@ class HermesCLI: def _secret_capture_callback(self, var_name: str, prompt: str, metadata=None) -> dict: return prompt_for_secret(self, var_name, prompt, metadata) + def _capture_modal_input_snapshot(self) -> None: + """Temporarily clear the input buffer and save the user's in-progress draft.""" + if self._modal_input_snapshot is not None or not getattr(self, "_app", None): + return + try: + buf = self._app.current_buffer + self._modal_input_snapshot = { + "text": buf.text, + "cursor_position": buf.cursor_position, + } + buf.reset() + except Exception: + self._modal_input_snapshot = None + + def _restore_modal_input_snapshot(self) -> None: + """Restore any draft text that was present before a modal prompt opened.""" + snapshot = self._modal_input_snapshot + self._modal_input_snapshot = None + if not snapshot or not getattr(self, "_app", None): + return + try: + buf = self._app.current_buffer + buf.text = snapshot.get("text", "") + buf.cursor_position = min(snapshot.get("cursor_position", 0), len(buf.text)) + except Exception: + pass + def _submit_secret_response(self, value: str) -> None: if not self._secret_state: return @@ -7130,6 +7170,7 @@ class HermesCLI: # Sudo password prompt state (similar mechanism to clarify) self._sudo_state = None # dict with response_queue when active self._sudo_deadline = 0 + self._modal_input_snapshot = None # Dangerous command approval state (similar mechanism to clarify) self._approval_state = None # dict with command, description, choices, selected, response_queue @@ -7201,7 +7242,6 @@ class HermesCLI: text = event.app.current_buffer.text self._sudo_state["response_queue"].put(text) self._sudo_state = None - event.app.current_buffer.reset() event.app.invalidate() return @@ -7406,7 +7446,6 @@ class HermesCLI: if self._sudo_state: self._sudo_state["response_queue"].put("") self._sudo_state = None - event.app.current_buffer.reset() event.app.invalidate() return diff --git a/cron/scheduler.py b/cron/scheduler.py index 33a9b89935..6a7f12acd6 100644 --- a/cron/scheduler.py +++ b/cron/scheduler.py @@ -44,7 +44,7 @@ logger = logging.getLogger(__name__) _KNOWN_DELIVERY_PLATFORMS = frozenset({ "telegram", "discord", "slack", "whatsapp", "signal", "matrix", "mattermost", "homeassistant", "dingtalk", "feishu", - "wecom", "sms", "email", "webhook", + "wecom", "sms", "email", "webhook", "bluebubbles", }) from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run @@ -91,7 +91,7 @@ def _resolve_delivery_target(job: dict) -> Optional[dict]: } # Origin missing (e.g. job created via API/script) — try each # platform's home channel as a fallback instead of silently dropping. - for platform_name in ("matrix", "telegram", "discord", "slack"): + for platform_name in ("matrix", "telegram", "discord", "slack", "bluebubbles"): chat_id = os.getenv(f"{platform_name.upper()}_HOME_CHANNEL", "") if chat_id: logger.info( @@ -236,6 +236,7 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option "wecom": Platform.WECOM, "email": Platform.EMAIL, "sms": Platform.SMS, + "bluebubbles": Platform.BLUEBUBBLES, } platform = platform_map.get(platform_name.lower()) if not platform: diff --git a/flake.lock b/flake.lock index 628e492f65..78ceba92d7 100644 --- a/flake.lock +++ b/flake.lock @@ -22,16 +22,16 @@ }, "nixpkgs": { "locked": { - "lastModified": 1751274312, - "narHash": "sha256-/bVBlRpECLVzjV19t5KMdMFWSwKLtb5RyXdjz3LJT+g=", + "lastModified": 1775036866, + "narHash": "sha256-ZojAnPuCdy657PbTq5V0Y+AHKhZAIwSIT2cb8UgAz/U=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "50ab793786d9de88ee30ec4e4c24fb4236fc2674", + "rev": "6201e203d09599479a3b3450ed24fa81537ebc4e", "type": "github" }, "original": { "owner": "NixOS", - "ref": "nixos-24.11", + "ref": "nixos-unstable", "repo": "nixpkgs", "type": "github" } diff --git a/flake.nix b/flake.nix index 87be89c85c..919fa434dc 100644 --- a/flake.nix +++ b/flake.nix @@ -2,7 +2,7 @@ description = "Hermes Agent - AI agent framework by Nous Research"; inputs = { - nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.11"; + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; flake-parts = { url = "github:hercules-ci/flake-parts"; inputs.nixpkgs-lib.follows = "nixpkgs"; diff --git a/gateway/channel_directory.py b/gateway/channel_directory.py index 0d12472175..022ebcae4e 100644 --- a/gateway/channel_directory.py +++ b/gateway/channel_directory.py @@ -77,7 +77,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]: logger.warning("Channel directory: failed to build %s: %s", platform.value, e) # Telegram, WhatsApp & Signal can't enumerate chats -- pull from session history - for plat_name in ("telegram", "whatsapp", "signal", "email", "sms"): + for plat_name in ("telegram", "whatsapp", "signal", "email", "sms", "bluebubbles"): if plat_name not in platforms: platforms[plat_name] = _build_from_sessions(plat_name) diff --git a/gateway/config.py b/gateway/config.py index 047ad542f5..96ee831701 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -63,6 +63,7 @@ class Platform(Enum): WEBHOOK = "webhook" FEISHU = "feishu" WECOM = "wecom" + BLUEBUBBLES = "bluebubbles" @dataclass @@ -287,6 +288,9 @@ class GatewayConfig: # WeCom uses extra dict for bot credentials elif platform == Platform.WECOM and config.extra.get("bot_id"): connected.append(platform) + # BlueBubbles uses extra dict for local server config + elif platform == Platform.BLUEBUBBLES and config.extra.get("server_url") and config.extra.get("password"): + connected.append(platform) return connected def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]: @@ -948,6 +952,29 @@ def _apply_env_overrides(config: GatewayConfig) -> None: name=os.getenv("WECOM_HOME_CHANNEL_NAME", "Home"), ) + # BlueBubbles (iMessage) + bluebubbles_server_url = os.getenv("BLUEBUBBLES_SERVER_URL") + bluebubbles_password = os.getenv("BLUEBUBBLES_PASSWORD") + if bluebubbles_server_url and bluebubbles_password: + if Platform.BLUEBUBBLES not in config.platforms: + config.platforms[Platform.BLUEBUBBLES] = PlatformConfig() + config.platforms[Platform.BLUEBUBBLES].enabled = True + config.platforms[Platform.BLUEBUBBLES].extra.update({ + "server_url": bluebubbles_server_url.rstrip("/"), + "password": bluebubbles_password, + "webhook_host": os.getenv("BLUEBUBBLES_WEBHOOK_HOST", "127.0.0.1"), + "webhook_port": int(os.getenv("BLUEBUBBLES_WEBHOOK_PORT", "8645")), + "webhook_path": os.getenv("BLUEBUBBLES_WEBHOOK_PATH", "/bluebubbles-webhook"), + "send_read_receipts": os.getenv("BLUEBUBBLES_SEND_READ_RECEIPTS", "true").lower() in ("true", "1", "yes"), + }) + bluebubbles_home = os.getenv("BLUEBUBBLES_HOME_CHANNEL") + if bluebubbles_home and Platform.BLUEBUBBLES in config.platforms: + config.platforms[Platform.BLUEBUBBLES].home_channel = HomeChannel( + platform=Platform.BLUEBUBBLES, + chat_id=bluebubbles_home, + name=os.getenv("BLUEBUBBLES_HOME_CHANNEL_NAME", "Home"), + ) + # Session settings idle_minutes = os.getenv("SESSION_IDLE_MINUTES") if idle_minutes: diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index a888eede94..bd07459ac8 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -298,6 +298,7 @@ SUPPORTED_DOCUMENT_TYPES = { ".pdf": "application/pdf", ".md": "text/markdown", ".txt": "text/plain", + ".log": "text/plain", ".zip": "application/zip", ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", @@ -407,6 +408,10 @@ class MessageEvent: # Auto-loaded skill for topic/channel bindings (e.g., Telegram DM Topics) auto_skill: Optional[str] = None + # Internal flag — set for synthetic events (e.g. background process + # completion notifications) that must bypass user authorization checks. + internal: bool = False + # Timestamps timestamp: datetime = field(default_factory=datetime.now) diff --git a/gateway/platforms/bluebubbles.py b/gateway/platforms/bluebubbles.py new file mode 100644 index 0000000000..83f94d3bf8 --- /dev/null +++ b/gateway/platforms/bluebubbles.py @@ -0,0 +1,828 @@ +"""BlueBubbles iMessage platform adapter. + +Uses the local BlueBubbles macOS server for outbound REST sends and inbound +webhooks. Supports text messaging, media attachments (images, voice, video, +documents), tapback reactions, typing indicators, and read receipts. + +Architecture based on PR #5869 (benjaminsehl) with inbound attachment +downloading from PR #4588 (YuhangLin). +""" + +import asyncio +import json +import logging +import os +import re +import uuid +from datetime import datetime +from typing import Any, Dict, List, Optional +from urllib.parse import quote + +import httpx + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import ( + BasePlatformAdapter, + MessageEvent, + MessageType, + SendResult, + cache_image_from_bytes, + cache_audio_from_bytes, + cache_document_from_bytes, +) + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +DEFAULT_WEBHOOK_HOST = "127.0.0.1" +DEFAULT_WEBHOOK_PORT = 8645 +DEFAULT_WEBHOOK_PATH = "/bluebubbles-webhook" +MAX_TEXT_LENGTH = 4000 + +# Tapback reaction codes (BlueBubbles associatedMessageType values) +_TAPBACK_ADDED = { + 2000: "love", 2001: "like", 2002: "dislike", + 2003: "laugh", 2004: "emphasize", 2005: "question", +} +_TAPBACK_REMOVED = { + 3000: "love", 3001: "like", 3002: "dislike", + 3003: "laugh", 3004: "emphasize", 3005: "question", +} + +# Webhook event types that carry user messages +_MESSAGE_EVENTS = {"new-message", "message", "updated-message"} + +# Log redaction patterns +_PHONE_RE = re.compile(r"\+?\d{7,15}") +_EMAIL_RE = re.compile(r"[\w.+-]+@[\w-]+\.[\w.]+") + + +def _redact(text: str) -> str: + """Redact phone numbers and emails from log output.""" + text = _PHONE_RE.sub("[REDACTED]", text) + text = _EMAIL_RE.sub("[REDACTED]", text) + return text + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def check_bluebubbles_requirements() -> bool: + try: + import aiohttp # noqa: F401 + import httpx as _httpx # noqa: F401 + except ImportError: + return False + return True + + +def _normalize_server_url(raw: str) -> str: + value = (raw or "").strip() + if not value: + return "" + if not re.match(r"^https?://", value, flags=re.I): + value = f"http://{value}" + return value.rstrip("/") + + +def _strip_markdown(text: str) -> str: + """Strip common markdown formatting for iMessage plain-text delivery.""" + text = re.sub(r"\*\*(.+?)\*\*", r"\1", text, flags=re.DOTALL) + text = re.sub(r"\*(.+?)\*", r"\1", text, flags=re.DOTALL) + text = re.sub(r"__(.+?)__", r"\1", text, flags=re.DOTALL) + text = re.sub(r"_(.+?)_", r"\1", text, flags=re.DOTALL) + text = re.sub(r"```[a-zA-Z0-9_+-]*\n?", "", text) + text = re.sub(r"`(.+?)`", r"\1", text) + text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE) + text = re.sub(r"\[([^\]]+)\]\(([^\)]+)\)", r"\1", text) + text = re.sub(r"\n{3,}", "\n\n", text) + return text.strip() + + +# --------------------------------------------------------------------------- +# Adapter +# --------------------------------------------------------------------------- + +class BlueBubblesAdapter(BasePlatformAdapter): + platform = Platform.BLUEBUBBLES + MAX_MESSAGE_LENGTH = MAX_TEXT_LENGTH + + def __init__(self, config: PlatformConfig): + super().__init__(config, Platform.BLUEBUBBLES) + extra = config.extra or {} + self.server_url = _normalize_server_url( + extra.get("server_url") or os.getenv("BLUEBUBBLES_SERVER_URL", "") + ) + self.password = extra.get("password") or os.getenv("BLUEBUBBLES_PASSWORD", "") + self.webhook_host = ( + extra.get("webhook_host") + or os.getenv("BLUEBUBBLES_WEBHOOK_HOST", DEFAULT_WEBHOOK_HOST) + ) + self.webhook_port = int( + extra.get("webhook_port") + or os.getenv("BLUEBUBBLES_WEBHOOK_PORT", str(DEFAULT_WEBHOOK_PORT)) + ) + self.webhook_path = ( + extra.get("webhook_path") + or os.getenv("BLUEBUBBLES_WEBHOOK_PATH", DEFAULT_WEBHOOK_PATH) + ) + if not str(self.webhook_path).startswith("/"): + self.webhook_path = f"/{self.webhook_path}" + self.send_read_receipts = bool(extra.get("send_read_receipts", True)) + self.client: Optional[httpx.AsyncClient] = None + self._runner = None + self._private_api_enabled: Optional[bool] = None + self._helper_connected: bool = False + self._guid_cache: Dict[str, str] = {} + + # ------------------------------------------------------------------ + # API helpers + # ------------------------------------------------------------------ + + def _api_url(self, path: str) -> str: + sep = "&" if "?" in path else "?" + return f"{self.server_url}{path}{sep}password={quote(self.password, safe='')}" + + async def _api_get(self, path: str) -> Dict[str, Any]: + assert self.client is not None + res = await self.client.get(self._api_url(path)) + res.raise_for_status() + return res.json() + + async def _api_post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]: + assert self.client is not None + res = await self.client.post(self._api_url(path), json=payload) + res.raise_for_status() + return res.json() + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + async def connect(self) -> bool: + if not self.server_url or not self.password: + logger.error( + "[bluebubbles] BLUEBUBBLES_SERVER_URL and BLUEBUBBLES_PASSWORD are required" + ) + return False + from aiohttp import web + + self.client = httpx.AsyncClient(timeout=30.0) + try: + await self._api_get("/api/v1/ping") + info = await self._api_get("/api/v1/server/info") + server_data = (info or {}).get("data", {}) + self._private_api_enabled = bool(server_data.get("private_api")) + self._helper_connected = bool(server_data.get("helper_connected")) + logger.info( + "[bluebubbles] connected to %s (private_api=%s, helper=%s)", + self.server_url, + self._private_api_enabled, + self._helper_connected, + ) + except Exception as exc: + logger.error( + "[bluebubbles] cannot reach server at %s: %s", self.server_url, exc + ) + if self.client: + await self.client.aclose() + self.client = None + return False + + app = web.Application() + app.router.add_get("/health", lambda _: web.Response(text="ok")) + app.router.add_post(self.webhook_path, self._handle_webhook) + self._runner = web.AppRunner(app) + await self._runner.setup() + site = web.TCPSite(self._runner, self.webhook_host, self.webhook_port) + await site.start() + self._mark_connected() + logger.info( + "[bluebubbles] webhook listening on http://%s:%s%s", + self.webhook_host, + self.webhook_port, + self.webhook_path, + ) + return True + + async def disconnect(self) -> None: + if self.client: + await self.client.aclose() + self.client = None + if self._runner: + await self._runner.cleanup() + self._runner = None + self._mark_disconnected() + + # ------------------------------------------------------------------ + # Chat GUID resolution + # ------------------------------------------------------------------ + + async def _resolve_chat_guid(self, target: str) -> Optional[str]: + """Resolve an email/phone to a BlueBubbles chat GUID. + + If *target* already contains a semicolon (raw GUID format like + ``iMessage;-;user@example.com``), it is returned as-is. Otherwise + the adapter queries the BlueBubbles chat list and matches on + ``chatIdentifier`` or participant address. + """ + target = (target or "").strip() + if not target: + return None + # Already a raw GUID + if ";" in target: + return target + if target in self._guid_cache: + return self._guid_cache[target] + try: + payload = await self._api_post( + "/api/v1/chat/query", + {"limit": 100, "offset": 0, "with": ["participants"]}, + ) + for chat in payload.get("data", []) or []: + guid = chat.get("guid") or chat.get("chatGuid") + identifier = chat.get("chatIdentifier") or chat.get("identifier") + if identifier == target: + if guid: + self._guid_cache[target] = guid + return guid + for part in chat.get("participants", []) or []: + if (part.get("address") or "").strip() == target and guid: + self._guid_cache[target] = guid + return guid + except Exception: + pass + return None + + async def _create_chat_for_handle( + self, address: str, message: str + ) -> SendResult: + """Create a new chat by sending the first message to *address*.""" + payload = { + "addresses": [address], + "message": message, + "tempGuid": f"temp-{datetime.utcnow().timestamp()}", + } + try: + res = await self._api_post("/api/v1/chat/new", payload) + data = res.get("data") or {} + msg_id = data.get("guid") or data.get("messageGuid") or "ok" + return SendResult(success=True, message_id=str(msg_id), raw_response=res) + except Exception as exc: + return SendResult(success=False, error=str(exc)) + + # ------------------------------------------------------------------ + # Text sending + # ------------------------------------------------------------------ + + async def send( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + text = _strip_markdown(content or "") + if not text: + return SendResult(success=False, error="BlueBubbles send requires text") + chunks = self.truncate_message(text, max_length=self.MAX_MESSAGE_LENGTH) + last = SendResult(success=True) + for chunk in chunks: + guid = await self._resolve_chat_guid(chat_id) + if not guid: + # If the target looks like an address, try creating a new chat + if self._private_api_enabled and ( + "@" in chat_id or re.match(r"^\+\d+", chat_id) + ): + return await self._create_chat_for_handle(chat_id, chunk) + return SendResult( + success=False, + error=f"BlueBubbles chat not found for target: {chat_id}", + ) + payload: Dict[str, Any] = { + "chatGuid": guid, + "tempGuid": f"temp-{datetime.utcnow().timestamp()}", + "message": chunk, + } + if reply_to and self._private_api_enabled and self._helper_connected: + payload["method"] = "private-api" + payload["selectedMessageGuid"] = reply_to + payload["partIndex"] = 0 + try: + res = await self._api_post("/api/v1/message/text", payload) + data = res.get("data") or {} + msg_id = data.get("guid") or data.get("messageGuid") or "ok" + last = SendResult( + success=True, message_id=str(msg_id), raw_response=res + ) + except Exception as exc: + return SendResult(success=False, error=str(exc)) + return last + + # ------------------------------------------------------------------ + # Media sending (outbound) + # ------------------------------------------------------------------ + + async def _send_attachment( + self, + chat_id: str, + file_path: str, + filename: Optional[str] = None, + caption: Optional[str] = None, + is_audio_message: bool = False, + ) -> SendResult: + """Send a file attachment via BlueBubbles multipart upload.""" + if not self.client: + return SendResult(success=False, error="Not connected") + if not os.path.isfile(file_path): + return SendResult(success=False, error=f"File not found: {file_path}") + + guid = await self._resolve_chat_guid(chat_id) + if not guid: + return SendResult(success=False, error=f"Chat not found: {chat_id}") + + fname = filename or os.path.basename(file_path) + try: + with open(file_path, "rb") as f: + files = {"attachment": (fname, f, "application/octet-stream")} + data: Dict[str, str] = { + "chatGuid": guid, + "name": fname, + "tempGuid": uuid.uuid4().hex, + } + if is_audio_message: + data["isAudioMessage"] = "true" + res = await self.client.post( + self._api_url("/api/v1/message/attachment"), + files=files, + data=data, + timeout=120, + ) + res.raise_for_status() + result = res.json() + + if caption: + await self.send(chat_id, caption) + + if result.get("status") == 200: + rdata = result.get("data") or {} + msg_id = rdata.get("guid") if isinstance(rdata, dict) else None + return SendResult( + success=True, message_id=msg_id, raw_response=result + ) + return SendResult( + success=False, + error=result.get("message", "Attachment upload failed"), + ) + except Exception as e: + return SendResult(success=False, error=str(e)) + + async def send_image( + self, + chat_id: str, + image_url: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + try: + from gateway.platforms.base import cache_image_from_url + + local_path = await cache_image_from_url(image_url) + return await self._send_attachment(chat_id, local_path, caption=caption) + except Exception: + return await super().send_image(chat_id, image_url, caption, reply_to) + + async def send_image_file( + self, + chat_id: str, + image_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, + ) -> SendResult: + return await self._send_attachment(chat_id, image_path, caption=caption) + + async def send_voice( + self, + chat_id: str, + audio_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, + ) -> SendResult: + return await self._send_attachment( + chat_id, audio_path, caption=caption, is_audio_message=True + ) + + async def send_video( + self, + chat_id: str, + video_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, + ) -> SendResult: + return await self._send_attachment(chat_id, video_path, caption=caption) + + async def send_document( + self, + chat_id: str, + file_path: str, + caption: Optional[str] = None, + file_name: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, + ) -> SendResult: + return await self._send_attachment( + chat_id, file_path, filename=file_name, caption=caption + ) + + async def send_animation( + self, + chat_id: str, + animation_url: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + return await self.send_image( + chat_id, animation_url, caption, reply_to, metadata + ) + + # ------------------------------------------------------------------ + # Typing indicators + # ------------------------------------------------------------------ + + async def send_typing(self, chat_id: str, metadata=None) -> None: + if not self._private_api_enabled or not self._helper_connected or not self.client: + return + try: + guid = await self._resolve_chat_guid(chat_id) + if guid: + encoded = quote(guid, safe="") + await self.client.post( + self._api_url(f"/api/v1/chat/{encoded}/typing"), timeout=5 + ) + except Exception: + pass + + async def stop_typing(self, chat_id: str) -> None: + if not self._private_api_enabled or not self._helper_connected or not self.client: + return + try: + guid = await self._resolve_chat_guid(chat_id) + if guid: + encoded = quote(guid, safe="") + await self.client.delete( + self._api_url(f"/api/v1/chat/{encoded}/typing"), timeout=5 + ) + except Exception: + pass + + # ------------------------------------------------------------------ + # Read receipts + # ------------------------------------------------------------------ + + async def mark_read(self, chat_id: str) -> bool: + if not self._private_api_enabled or not self._helper_connected or not self.client: + return False + try: + guid = await self._resolve_chat_guid(chat_id) + if guid: + encoded = quote(guid, safe="") + await self.client.post( + self._api_url(f"/api/v1/chat/{encoded}/read"), timeout=5 + ) + return True + except Exception: + pass + return False + + # ------------------------------------------------------------------ + # Tapback reactions + # ------------------------------------------------------------------ + + async def send_reaction( + self, + chat_id: str, + message_guid: str, + reaction: str, + part_index: int = 0, + ) -> SendResult: + """Send a tapback reaction (requires Private API helper).""" + if not self._private_api_enabled or not self._helper_connected: + return SendResult( + success=False, error="Private API helper not connected" + ) + guid = await self._resolve_chat_guid(chat_id) + if not guid: + return SendResult(success=False, error=f"Chat not found: {chat_id}") + try: + res = await self._api_post( + "/api/v1/message/react", + { + "chatGuid": guid, + "selectedMessageGuid": message_guid, + "reaction": reaction, + "partIndex": part_index, + }, + ) + return SendResult(success=True, raw_response=res) + except Exception as exc: + return SendResult(success=False, error=str(exc)) + + # ------------------------------------------------------------------ + # Chat info + # ------------------------------------------------------------------ + + async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + is_group = ";+;" in (chat_id or "") + info: Dict[str, Any] = { + "name": chat_id, + "type": "group" if is_group else "dm", + } + try: + guid = await self._resolve_chat_guid(chat_id) + if guid: + encoded = quote(guid, safe="") + res = await self._api_get( + f"/api/v1/chat/{encoded}?with=participants" + ) + data = (res or {}).get("data", {}) + display_name = ( + data.get("displayName") + or data.get("chatIdentifier") + or chat_id + ) + participants = [] + for p in data.get("participants", []) or []: + addr = (p.get("address") or "").strip() + if addr: + participants.append(addr) + info["name"] = display_name + if participants: + info["participants"] = participants + except Exception: + pass + return info + + def format_message(self, content: str) -> str: + return _strip_markdown(content) + + # ------------------------------------------------------------------ + # Inbound attachment downloading (from #4588) + # ------------------------------------------------------------------ + + async def _download_attachment( + self, att_guid: str, att_meta: Dict[str, Any] + ) -> Optional[str]: + """Download an attachment from BlueBubbles and cache it locally. + + Returns the local file path on success, None on failure. + """ + if not self.client: + return None + try: + encoded = quote(att_guid, safe="") + resp = await self.client.get( + self._api_url(f"/api/v1/attachment/{encoded}/download"), + timeout=60, + follow_redirects=True, + ) + resp.raise_for_status() + data = resp.content + + mime = (att_meta.get("mimeType") or "").lower() + transfer_name = att_meta.get("transferName", "") + + if mime.startswith("image/"): + ext_map = { + "image/jpeg": ".jpg", + "image/png": ".png", + "image/gif": ".gif", + "image/webp": ".webp", + "image/heic": ".jpg", + "image/heif": ".jpg", + "image/tiff": ".jpg", + } + ext = ext_map.get(mime, ".jpg") + return cache_image_from_bytes(data, ext) + + if mime.startswith("audio/"): + ext_map = { + "audio/mp3": ".mp3", + "audio/mpeg": ".mp3", + "audio/ogg": ".ogg", + "audio/wav": ".wav", + "audio/x-caf": ".mp3", + "audio/mp4": ".m4a", + "audio/aac": ".m4a", + } + ext = ext_map.get(mime, ".mp3") + return cache_audio_from_bytes(data, ext) + + # Videos, documents, and everything else + filename = transfer_name or f"file_{uuid.uuid4().hex[:8]}" + return cache_document_from_bytes(data, filename) + + except Exception as exc: + logger.warning( + "[bluebubbles] failed to download attachment %s: %s", + _redact(att_guid), + exc, + ) + return None + + # ------------------------------------------------------------------ + # Webhook handling + # ------------------------------------------------------------------ + + def _extract_payload_record( + self, payload: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + data = payload.get("data") + if isinstance(data, dict): + return data + if isinstance(data, list): + for item in data: + if isinstance(item, dict): + return item + if isinstance(payload.get("message"), dict): + return payload.get("message") + return payload if isinstance(payload, dict) else None + + @staticmethod + def _value(*candidates: Any) -> Optional[str]: + for candidate in candidates: + if isinstance(candidate, str) and candidate.strip(): + return candidate.strip() + return None + + async def _handle_webhook(self, request): + from aiohttp import web + + token = ( + request.query.get("password") + or request.query.get("guid") + or request.headers.get("x-password") + or request.headers.get("x-guid") + or request.headers.get("x-bluebubbles-guid") + ) + if token != self.password: + return web.json_response({"error": "unauthorized"}, status=401) + try: + raw = await request.read() + body = raw.decode("utf-8", errors="replace") + try: + payload = json.loads(body) + except Exception: + from urllib.parse import parse_qs + + form = parse_qs(body) + payload_str = ( + form.get("payload") + or form.get("data") + or form.get("message") + or [""] + )[0] + payload = json.loads(payload_str) if payload_str else {} + except Exception as exc: + logger.error("[bluebubbles] webhook parse error: %s", exc) + return web.json_response({"error": "invalid payload"}, status=400) + + event_type = self._value(payload.get("type"), payload.get("event")) or "" + # Only process message events; silently acknowledge everything else + if event_type and event_type not in _MESSAGE_EVENTS: + return web.Response(text="ok") + + record = self._extract_payload_record(payload) or {} + is_from_me = bool( + record.get("isFromMe") + or record.get("fromMe") + or record.get("is_from_me") + ) + if is_from_me: + return web.Response(text="ok") + + # Skip tapback reactions delivered as messages + assoc_type = record.get("associatedMessageType") + if isinstance(assoc_type, int) and assoc_type in { + **_TAPBACK_ADDED, + **_TAPBACK_REMOVED, + }: + return web.Response(text="ok") + + text = ( + self._value( + record.get("text"), record.get("message"), record.get("body") + ) + or "" + ) + + # --- Inbound attachment handling --- + attachments = record.get("attachments") or [] + media_urls: List[str] = [] + media_types: List[str] = [] + msg_type = MessageType.TEXT + + for att in attachments: + att_guid = att.get("guid", "") + if not att_guid: + continue + cached = await self._download_attachment(att_guid, att) + if cached: + mime = (att.get("mimeType") or "").lower() + media_urls.append(cached) + media_types.append(mime) + if mime.startswith("image/"): + msg_type = MessageType.PHOTO + elif mime.startswith("audio/") or (att.get("uti") or "").endswith( + "caf" + ): + msg_type = MessageType.VOICE + elif mime.startswith("video/"): + msg_type = MessageType.VIDEO + else: + msg_type = MessageType.DOCUMENT + + # With multiple attachments, prefer PHOTO if any images present + if len(media_urls) > 1: + mime_prefixes = {(m or "").split("/")[0] for m in media_types} + if "image" in mime_prefixes: + msg_type = MessageType.PHOTO + + if not text and media_urls: + text = "(attachment)" + # --- End attachment handling --- + + chat_guid = self._value( + record.get("chatGuid"), + payload.get("chatGuid"), + record.get("chat_guid"), + payload.get("chat_guid"), + payload.get("guid"), + ) + chat_identifier = self._value( + record.get("chatIdentifier"), + record.get("identifier"), + payload.get("chatIdentifier"), + payload.get("identifier"), + ) + sender = ( + self._value( + record.get("handle", {}).get("address") + if isinstance(record.get("handle"), dict) + else None, + record.get("sender"), + record.get("from"), + record.get("address"), + ) + or chat_identifier + or chat_guid + ) + if not (chat_guid or chat_identifier) and sender: + chat_identifier = sender + if not sender or not (chat_guid or chat_identifier) or not text: + return web.json_response({"error": "missing message fields"}, status=400) + + session_chat_id = chat_guid or chat_identifier + is_group = bool(record.get("isGroup")) or (";+;" in (chat_guid or "")) + source = self.build_source( + chat_id=session_chat_id, + chat_name=chat_identifier or sender, + chat_type="group" if is_group else "dm", + user_id=sender, + user_name=sender, + chat_id_alt=chat_identifier, + ) + event = MessageEvent( + text=text, + message_type=msg_type, + source=source, + raw_message=payload, + message_id=self._value( + record.get("guid"), + record.get("messageGuid"), + record.get("id"), + ), + reply_to_message_id=self._value( + record.get("threadOriginatorGuid"), + record.get("associatedMessageGuid"), + ), + media_urls=media_urls, + media_types=media_types, + ) + task = asyncio.create_task(self.handle_message(event)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + + # Fire-and-forget read receipt + if self.send_read_receipts and session_chat_id: + asyncio.create_task(self.mark_read(session_chat_id)) + + return web.Response(text="ok") diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index b802f5712c..2ace06e779 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -1767,8 +1767,9 @@ class DiscordAdapter(BasePlatformAdapter): if hasattr(interaction.channel, "guild") and interaction.channel.guild: chat_name = f"{interaction.channel.guild.name} / #{chat_name}" - # Get channel topic (if available) - chat_topic = getattr(interaction.channel, "topic", None) + # Get channel topic (if available). + # For forum threads, inherit the parent forum's topic. + chat_topic = self._get_effective_topic(interaction.channel, is_thread=is_thread) source = self.build_source( chat_id=str(interaction.channel_id), @@ -1842,6 +1843,10 @@ class DiscordAdapter(BasePlatformAdapter): chat_name = f"{guild_name} / {thread_name}" if guild_name else thread_name + # Inherit forum topic when the thread was created inside a forum channel. + _chan = getattr(interaction, "channel", None) + chat_topic = self._get_effective_topic(_chan, is_thread=True) if _chan else None + source = self.build_source( chat_id=thread_id, chat_name=chat_name, @@ -1849,6 +1854,7 @@ class DiscordAdapter(BasePlatformAdapter): user_id=str(interaction.user.id), user_name=interaction.user.display_name, thread_id=thread_id, + chat_topic=chat_topic, ) event = MessageEvent( @@ -2134,6 +2140,15 @@ class DiscordAdapter(BasePlatformAdapter): return True return False + def _get_effective_topic(self, channel: Any, is_thread: bool = False) -> Optional[str]: + """Return the channel topic, falling back to the parent forum's topic for forum threads.""" + topic = getattr(channel, "topic", None) + if not topic and is_thread: + parent = getattr(channel, "parent", None) + if parent and self._is_forum_parent(parent): + topic = getattr(parent, "topic", None) + return topic + def _format_thread_chat_name(self, thread: Any) -> str: """Build a readable chat name for thread-like Discord channels, including forum context when available.""" thread_name = getattr(thread, "name", None) or str(getattr(thread, "id", "thread")) @@ -2301,8 +2316,10 @@ class DiscordAdapter(BasePlatformAdapter): if hasattr(message.channel, "guild") and message.channel.guild: chat_name = f"{message.channel.guild.name} / #{chat_name}" - # Get channel topic (if available - TextChannels have topics, DMs/threads don't) - chat_topic = getattr(message.channel, "topic", None) + # Get channel topic (if available - TextChannels have topics, DMs/threads don't). + # For threads whose parent is a forum channel, inherit the parent's topic + # so forum descriptions (e.g. project instructions) appear in the session context. + chat_topic = self._get_effective_topic(message.channel, is_thread=is_thread) # Build source source = self.build_source( @@ -2365,7 +2382,7 @@ class DiscordAdapter(BasePlatformAdapter): ext or "unknown", content_type, ) else: - MAX_DOC_BYTES = 20 * 1024 * 1024 + MAX_DOC_BYTES = 32 * 1024 * 1024 if att.size and att.size > MAX_DOC_BYTES: logger.warning( "[Discord] Document too large (%s bytes), skipping: %s", @@ -2389,9 +2406,9 @@ class DiscordAdapter(BasePlatformAdapter): media_urls.append(cached_path) media_types.append(doc_mime) logger.info("[Discord] Cached user document: %s", cached_path) - # Inject text content for .txt/.md files (capped at 100 KB) + # Inject text content for plain-text documents (capped at 100 KB) MAX_TEXT_INJECT_BYTES = 100 * 1024 - if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES: + if ext in (".md", ".txt", ".log") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES: try: text_content = raw_bytes.decode("utf-8") display_name = att.filename or f"document{ext}" diff --git a/gateway/platforms/signal.py b/gateway/platforms/signal.py index 66d455ccaf..08b62f2a6d 100644 --- a/gateway/platforms/signal.py +++ b/gateway/platforms/signal.py @@ -647,7 +647,11 @@ class SignalAdapter(BasePlatformAdapter): if result is not None: self._track_sent_timestamp(result) - return SendResult(success=True) + # Use the timestamp from the RPC result as a pseudo message_id. + # Signal doesn't have real message IDs, but the stream consumer + # needs a truthy value to follow its edit→fallback path correctly. + _msg_id = str(result.get("timestamp", "")) if isinstance(result, dict) else None + return SendResult(success=True, message_id=_msg_id or None) return SendResult(success=False, error="RPC send failed") def _track_sent_timestamp(self, rpc_result) -> None: @@ -837,6 +841,11 @@ class SignalAdapter(BasePlatformAdapter): except asyncio.CancelledError: pass + async def stop_typing(self, chat_id: str) -> None: + """Public interface for stopping typing — called by base adapter's + _keep_typing finally block to clean up platform-level typing tasks.""" + await self._stop_typing_indicator(chat_id) + # ------------------------------------------------------------------ # Chat Info # ------------------------------------------------------------------ diff --git a/gateway/platforms/slack.py b/gateway/platforms/slack.py index 7af313d325..26184b7eb5 100644 --- a/gateway/platforms/slack.py +++ b/gateway/platforms/slack.py @@ -14,7 +14,7 @@ import logging import os import re import time -from typing import Dict, Optional, Any +from typing import Dict, Optional, Any, Tuple try: from slack_bolt.async_app import AsyncApp @@ -95,6 +95,12 @@ class SlackAdapter(BasePlatformAdapter): # respond to ALL subsequent messages in that thread automatically. self._mentioned_threads: set = set() self._MENTIONED_THREADS_MAX = 5000 + # Assistant thread metadata keyed by (channel_id, thread_ts). Slack's + # AI Assistant lifecycle events can arrive before/alongside message + # events, and they carry the user/thread identity needed for stable + # session + memory scoping. + self._assistant_threads: Dict[Tuple[str, str], Dict[str, str]] = {} + self._ASSISTANT_THREADS_MAX = 5000 async def connect(self) -> bool: """Connect to Slack via Socket Mode.""" @@ -181,6 +187,14 @@ class SlackAdapter(BasePlatformAdapter): async def handle_app_mention(event, say): pass + @self._app.event("assistant_thread_started") + async def handle_assistant_thread_started(event, say): + await self._handle_assistant_thread_lifecycle_event(event) + + @self._app.event("assistant_thread_context_changed") + async def handle_assistant_thread_context_changed(event, say): + await self._handle_assistant_thread_lifecycle_event(event) + # Register slash command handler @self._app.command("/hermes") async def handle_hermes_command(ack, command): @@ -755,6 +769,135 @@ class SlackAdapter(BasePlatformAdapter): # ----- Internal handlers ----- + def _assistant_thread_key(self, channel_id: str, thread_ts: str) -> Optional[Tuple[str, str]]: + """Return a stable cache key for Slack assistant thread metadata.""" + if not channel_id or not thread_ts: + return None + return (str(channel_id), str(thread_ts)) + + def _extract_assistant_thread_metadata(self, event: dict) -> Dict[str, str]: + """Extract Slack Assistant thread identity data from an event payload.""" + assistant_thread = event.get("assistant_thread") or {} + context = assistant_thread.get("context") or event.get("context") or {} + + channel_id = ( + assistant_thread.get("channel_id") + or event.get("channel") + or context.get("channel_id") + or "" + ) + thread_ts = ( + assistant_thread.get("thread_ts") + or event.get("thread_ts") + or event.get("message_ts") + or "" + ) + user_id = ( + assistant_thread.get("user_id") + or event.get("user") + or context.get("user_id") + or "" + ) + team_id = ( + event.get("team") + or event.get("team_id") + or assistant_thread.get("team_id") + or "" + ) + context_channel_id = context.get("channel_id") or "" + + return { + "channel_id": str(channel_id) if channel_id else "", + "thread_ts": str(thread_ts) if thread_ts else "", + "user_id": str(user_id) if user_id else "", + "team_id": str(team_id) if team_id else "", + "context_channel_id": str(context_channel_id) if context_channel_id else "", + } + + def _cache_assistant_thread_metadata(self, metadata: Dict[str, str]) -> None: + """Remember assistant thread identity data for later message events.""" + channel_id = metadata.get("channel_id", "") + thread_ts = metadata.get("thread_ts", "") + key = self._assistant_thread_key(channel_id, thread_ts) + if not key: + return + + existing = self._assistant_threads.get(key, {}) + merged = dict(existing) + merged.update({k: v for k, v in metadata.items() if v}) + self._assistant_threads[key] = merged + + # Evict oldest entries when the cache exceeds the limit + if len(self._assistant_threads) > self._ASSISTANT_THREADS_MAX: + excess = len(self._assistant_threads) - self._ASSISTANT_THREADS_MAX // 2 + for old_key in list(self._assistant_threads)[:excess]: + del self._assistant_threads[old_key] + + team_id = merged.get("team_id", "") + if team_id and channel_id: + self._channel_team[channel_id] = team_id + + def _lookup_assistant_thread_metadata( + self, + event: dict, + channel_id: str = "", + thread_ts: str = "", + ) -> Dict[str, str]: + """Load cached assistant-thread metadata that matches the current event.""" + metadata = self._extract_assistant_thread_metadata(event) + if channel_id and not metadata.get("channel_id"): + metadata["channel_id"] = channel_id + if thread_ts and not metadata.get("thread_ts"): + metadata["thread_ts"] = thread_ts + + key = self._assistant_thread_key( + metadata.get("channel_id", ""), + metadata.get("thread_ts", ""), + ) + cached = self._assistant_threads.get(key, {}) if key else {} + if cached: + merged = dict(cached) + merged.update({k: v for k, v in metadata.items() if v}) + return merged + return metadata + + def _seed_assistant_thread_session(self, metadata: Dict[str, str]) -> None: + """Prime the session store so assistant threads get stable user scoping.""" + session_store = getattr(self, "_session_store", None) + if not session_store: + return + + channel_id = metadata.get("channel_id", "") + thread_ts = metadata.get("thread_ts", "") + user_id = metadata.get("user_id", "") + if not channel_id or not thread_ts or not user_id: + return + + source = self.build_source( + chat_id=channel_id, + chat_name=channel_id, + chat_type="dm", + user_id=user_id, + thread_id=thread_ts, + chat_topic=metadata.get("context_channel_id") or None, + ) + + try: + session_store.get_or_create_session(source) + except Exception: + logger.debug( + "[Slack] Failed to seed assistant thread session for %s/%s", + channel_id, + thread_ts, + exc_info=True, + ) + + async def _handle_assistant_thread_lifecycle_event(self, event: dict) -> None: + """Handle Slack Assistant lifecycle events that carry user/thread identity.""" + metadata = self._extract_assistant_thread_metadata(event) + self._cache_assistant_thread_metadata(metadata) + self._seed_assistant_thread_session(metadata) + async def _handle_slack_message(self, event: dict) -> None: """Handle an incoming Slack message event.""" # Dedup: Slack Socket Mode can redeliver events after reconnects (#4777) @@ -781,10 +924,21 @@ class SlackAdapter(BasePlatformAdapter): return text = event.get("text", "") - user_id = event.get("user", "") channel_id = event.get("channel", "") ts = event.get("ts", "") - team_id = event.get("team", "") + assistant_meta = self._lookup_assistant_thread_metadata( + event, + channel_id=channel_id, + thread_ts=event.get("thread_ts", ""), + ) + user_id = event.get("user") or assistant_meta.get("user_id", "") + if not channel_id: + channel_id = assistant_meta.get("channel_id", "") + team_id = ( + event.get("team") + or event.get("team_id") + or assistant_meta.get("team_id", "") + ) # Track which workspace owns this channel if team_id and channel_id: @@ -792,6 +946,8 @@ class SlackAdapter(BasePlatformAdapter): # Determine if this is a DM or channel message channel_type = event.get("channel_type", "") + if not channel_type and channel_id.startswith("D"): + channel_type = "im" is_dm = channel_type == "im" # Build thread_ts for session keying. @@ -800,7 +956,7 @@ class SlackAdapter(BasePlatformAdapter): # In DMs: only use the real thread_ts — top-level DMs should share # one continuous session, threaded DMs get their own session. if is_dm: - thread_ts = event.get("thread_ts") # None for top-level DMs + thread_ts = event.get("thread_ts") or assistant_meta.get("thread_ts") # None for top-level DMs else: thread_ts = event.get("thread_ts") or ts # ts fallback for channels diff --git a/gateway/run.py b/gateway/run.py index 7a551be168..339954f5be 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -184,6 +184,8 @@ if _config_path.exists(): # Env var from .env takes precedence (already in os.environ). if "gateway_timeout" in _agent_cfg and "HERMES_AGENT_TIMEOUT" not in os.environ: os.environ["HERMES_AGENT_TIMEOUT"] = str(_agent_cfg["gateway_timeout"]) + if "gateway_timeout_warning" in _agent_cfg and "HERMES_AGENT_TIMEOUT_WARNING" not in os.environ: + os.environ["HERMES_AGENT_TIMEOUT_WARNING"] = str(_agent_cfg["gateway_timeout_warning"]) # Timezone: bridge config.yaml → HERMES_TIMEZONE env var. # HERMES_TIMEZONE from .env takes precedence (already in os.environ). _tz_cfg = _cfg.get("timezone", "") @@ -1073,6 +1075,7 @@ class GatewayRunner: "MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS", "FEISHU_ALLOWED_USERS", "WECOM_ALLOWED_USERS", + "BLUEBUBBLES_ALLOWED_USERS", "GATEWAY_ALLOWED_USERS") ) _allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") or any( @@ -1083,7 +1086,8 @@ class GatewayRunner: "SMS_ALLOW_ALL_USERS", "MATTERMOST_ALLOW_ALL_USERS", "MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS", "FEISHU_ALLOW_ALL_USERS", - "WECOM_ALLOW_ALL_USERS") + "WECOM_ALLOW_ALL_USERS", + "BLUEBUBBLES_ALLOW_ALL_USERS") ) if not _any_allowlist and not _allow_all: logger.warning( @@ -1654,6 +1658,13 @@ class GatewayRunner: adapter.gateway_runner = self # For cross-platform delivery return adapter + elif platform == Platform.BLUEBUBBLES: + from gateway.platforms.bluebubbles import BlueBubblesAdapter, check_bluebubbles_requirements + if not check_bluebubbles_requirements(): + logger.warning("BlueBubbles: aiohttp/httpx missing or BLUEBUBBLES_SERVER_URL/BLUEBUBBLES_PASSWORD not configured") + return None + return BlueBubblesAdapter(config) + return None def _is_user_authorized(self, source: SessionSource) -> bool: @@ -1692,6 +1703,7 @@ class GatewayRunner: Platform.DINGTALK: "DINGTALK_ALLOWED_USERS", Platform.FEISHU: "FEISHU_ALLOWED_USERS", Platform.WECOM: "WECOM_ALLOWED_USERS", + Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOWED_USERS", } platform_allow_all_map = { Platform.TELEGRAM: "TELEGRAM_ALLOW_ALL_USERS", @@ -1706,6 +1718,7 @@ class GatewayRunner: Platform.DINGTALK: "DINGTALK_ALLOW_ALL_USERS", Platform.FEISHU: "FEISHU_ALLOW_ALL_USERS", Platform.WECOM: "WECOM_ALLOW_ALL_USERS", + Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOW_ALL_USERS", } # Per-platform allow-all flag (e.g., DISCORD_ALLOW_ALL_USERS=true) @@ -1779,8 +1792,11 @@ class GatewayRunner: """ source = event.source - # Check if user is authorized - if not self._is_user_authorized(source): + # Internal events (e.g. background-process completion notifications) + # are system-generated and must skip user authorization. + if getattr(event, "internal", False): + pass + elif not self._is_user_authorized(source): logger.warning("Unauthorized user: %s (%s) on %s", source.user_id, source.user_name, source.platform.value) # In DMs: offer pairing code. In groups: silently ignore. if source.chat_type == "dm" and self._get_unauthorized_dm_behavior(source.platform) == "pair": @@ -5264,19 +5280,28 @@ class GatewayRunner: agent = self._running_agents.get(session_key) if agent and hasattr(agent, "session_total_tokens") and agent.session_api_calls > 0: - lines = [ - "📊 **Session Token Usage**", - f"Prompt (input): {agent.session_prompt_tokens:,}", - f"Completion (output): {agent.session_completion_tokens:,}", - f"Total: {agent.session_total_tokens:,}", - f"API calls: {agent.session_api_calls}", - ] + lines = [] + + # Rate limits first (when available from provider headers) + rl_state = agent.get_rate_limit_state() + if rl_state and rl_state.has_data: + from agent.rate_limit_tracker import format_rate_limit_compact + lines.append(f"⏱️ **Rate Limits:** {format_rate_limit_compact(rl_state)}") + lines.append("") + + # Session token usage + lines.append("📊 **Session Token Usage**") + lines.append(f"Prompt (input): {agent.session_prompt_tokens:,}") + lines.append(f"Completion (output): {agent.session_completion_tokens:,}") + lines.append(f"Total: {agent.session_total_tokens:,}") + lines.append(f"API calls: {agent.session_api_calls}") ctx = agent.context_compressor if ctx.last_prompt_tokens: pct = min(100, ctx.last_prompt_tokens / ctx.context_length * 100) if ctx.context_length else 0 lines.append(f"Context: {ctx.last_prompt_tokens:,} / {ctx.context_length:,} ({pct:.0f}%)") if ctx.compression_count: lines.append(f"Compressions: {ctx.compression_count}") + return "\n".join(lines) # No running agent -- check session history for a rough count @@ -5518,7 +5543,7 @@ class GatewayRunner: Platform.TELEGRAM, Platform.DISCORD, Platform.SLACK, Platform.WHATSAPP, Platform.SIGNAL, Platform.MATTERMOST, Platform.MATRIX, Platform.HOMEASSISTANT, Platform.EMAIL, Platform.SMS, Platform.DINGTALK, - Platform.FEISHU, Platform.WECOM, Platform.LOCAL, + Platform.FEISHU, Platform.WECOM, Platform.BLUEBUBBLES, Platform.LOCAL, }) async def _handle_update_command(self, event: MessageEvent) -> str: @@ -6158,6 +6183,7 @@ class GatewayRunner: text=synth_text, message_type=MessageType.TEXT, source=_source, + internal=True, ) logger.info( "Process %s finished — injecting agent notification for session %s", @@ -6308,7 +6334,15 @@ class GatewayRunner: # Falls back to env vars for backward compatibility. # YAML 1.1 parses bare `off` as boolean False — normalise before # the `or` chain so it doesn't silently fall through to "all". - _raw_tp = user_config.get("display", {}).get("tool_progress") + # + # Per-platform overrides (display.tool_progress_overrides) take + # priority over the global setting — e.g. Signal users can set + # tool_progress to "off" while keeping Telegram on "all". + _display_cfg = user_config.get("display", {}) + _overrides = _display_cfg.get("tool_progress_overrides", {}) + _raw_tp = _overrides.get(platform_key) + if _raw_tp is None: + _raw_tp = _display_cfg.get("tool_progress") if _raw_tp is False: _raw_tp = "off" progress_mode = ( @@ -6412,6 +6446,18 @@ class GatewayRunner: if not adapter: return + # Skip tool progress for platforms that don't support message + # editing (e.g. iMessage/BlueBubbles) — each progress update + # would become a separate message bubble, which is noisy. + from gateway.platforms.base import BasePlatformAdapter as _BaseAdapter + if type(adapter).edit_message is _BaseAdapter.edit_message: + while not progress_queue.empty(): + try: + progress_queue.get_nowait() + except Exception: + break + return + progress_lines = [] # Accumulated tool lines progress_msg_id = None # ID of the progress message to edit can_edit = True # False once an edit fails (platform doesn't support it) @@ -7106,6 +7152,9 @@ class GatewayRunner: # Default 1800s (30 min inactivity). 0 = unlimited. _agent_timeout_raw = float(os.getenv("HERMES_AGENT_TIMEOUT", 1800)) _agent_timeout = _agent_timeout_raw if _agent_timeout_raw > 0 else None + _agent_warning_raw = float(os.getenv("HERMES_AGENT_TIMEOUT_WARNING", 900)) + _agent_warning = _agent_warning_raw if _agent_warning_raw > 0 else None + _warning_fired = False loop = asyncio.get_event_loop() _executor_task = asyncio.ensure_future( loop.run_in_executor(None, run_sync) @@ -7138,6 +7187,25 @@ class GatewayRunner: _idle_secs = _act.get("seconds_since_activity", 0.0) except Exception: pass + # Staged warning: fire once before escalating to full timeout. + if (not _warning_fired and _agent_warning is not None + and _idle_secs >= _agent_warning): + _warning_fired = True + _warn_adapter = self.adapters.get(source.platform) + if _warn_adapter: + _elapsed_warn = int(_agent_warning // 60) or 1 + _remaining_mins = int((_agent_timeout - _agent_warning) // 60) or 1 + try: + await _warn_adapter.send( + source.chat_id, + f"⚠️ No activity for {_elapsed_warn} min. " + f"If the agent does not respond soon, it will " + f"be timed out in {_remaining_mins} min. " + f"You can continue waiting or use /reset.", + metadata=_status_thread_metadata, + ) + except Exception as _warn_err: + logger.debug("Inactivity warning send error: %s", _warn_err) if _idle_secs >= _agent_timeout: _inactivity_timeout = True break diff --git a/gateway/session.py b/gateway/session.py index 64f04ad9c9..72c3eb1618 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -193,6 +193,7 @@ _PII_SAFE_PLATFORMS = frozenset({ Platform.WHATSAPP, Platform.SIGNAL, Platform.TELEGRAM, + Platform.BLUEBUBBLES, }) """Platforms where user IDs can be safely redacted (no in-message mention system that requires raw IDs). Discord is excluded because mentions use ``<@user_id>`` diff --git a/gateway/stream_consumer.py b/gateway/stream_consumer.py index 5522c631db..cc3d64d136 100644 --- a/gateway/stream_consumer.py +++ b/gateway/stream_consumer.py @@ -353,6 +353,17 @@ class GatewayStreamConsumer: self._message_id = result.message_id self._already_sent = True self._last_sent_text = text + elif result.success: + # Platform accepted the message but returned no message_id + # (e.g. Signal). Can't edit without an ID — switch to + # fallback mode: suppress intermediate deltas, send only + # the missing tail once the final response is ready. + self._already_sent = True + self._edit_supported = False + self._fallback_prefix = self._clean_for_display(text) + self._fallback_final_send = True + # Sentinel prevents re-entering this branch on every delta + self._message_id = "__no_edit__" else: # Initial send failed — disable streaming for this session self._edit_supported = False diff --git a/hermes_cli/banner.py b/hermes_cli/banner.py index 03712c272d..b29805872d 100644 --- a/hermes_cli/banner.py +++ b/hermes_cli/banner.py @@ -295,10 +295,16 @@ def _format_context_length(tokens: int) -> str: """Format a token count for display (e.g. 128000 → '128K', 1048576 → '1M').""" if tokens >= 1_000_000: val = tokens / 1_000_000 - return f"{val:g}M" + rounded = round(val) + if abs(val - rounded) < 0.05: + return f"{rounded}M" + return f"{val:.1f}M" elif tokens >= 1_000: val = tokens / 1_000 - return f"{val:g}K" + rounded = round(val) + if abs(val - rounded) < 0.05: + return f"{rounded}K" + return f"{val:.1f}K" return str(tokens) diff --git a/hermes_cli/commands.py b/hermes_cli/commands.py index 39dc4569cd..70d9cb8aa3 100644 --- a/hermes_cli/commands.py +++ b/hermes_cli/commands.py @@ -129,7 +129,7 @@ COMMAND_REGISTRY: list[CommandDef] = [ CommandDef("commands", "Browse all commands and skills (paginated)", "Info", gateway_only=True, args_hint="[page]"), CommandDef("help", "Show available commands", "Info"), - CommandDef("usage", "Show token usage for the current session", "Info"), + CommandDef("usage", "Show token usage and rate limits for the current session", "Info"), CommandDef("insights", "Show usage insights and analytics", "Info", args_hint="[days]"), CommandDef("platforms", "Show gateway/messaging platform status", "Info", diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 7c860f1593..a981b1bbbf 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -39,6 +39,7 @@ _EXTRA_ENV_KEYS = frozenset({ "DINGTALK_CLIENT_ID", "DINGTALK_CLIENT_SECRET", "FEISHU_APP_ID", "FEISHU_APP_SECRET", "FEISHU_ENCRYPT_KEY", "FEISHU_VERIFICATION_TOKEN", "WECOM_BOT_ID", "WECOM_SECRET", + "BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_PASSWORD", "TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT", "WHATSAPP_MODE", "WHATSAPP_ENABLED", "MATTERMOST_HOME_CHANNEL", "MATTERMOST_REPLY_MODE", @@ -230,6 +231,10 @@ DEFAULT_CONFIG = { # (force on/off for all models), or a list of model-name substrings # to match (e.g. ["gpt", "codex", "gemini", "qwen"]). "tool_use_enforcement": "auto", + # Staged inactivity warning: send a warning to the user at this + # threshold before escalating to a full timeout. The warning fires + # once per run and does not interrupt the agent. 0 = disable warning. + "gateway_timeout_warning": 900, }, "terminal": { @@ -392,6 +397,7 @@ DEFAULT_CONFIG = { "show_cost": False, # Show $ cost in the status bar (off by default) "skin": "default", "tool_progress_command": False, # Enable /verbose command in messaging gateway + "tool_progress_overrides": {}, # Per-platform overrides: {"signal": "off", "telegram": "all"} "tool_preview_length": 0, # Max chars for tool call previews (0 = no limit, show full paths/commands) }, @@ -563,7 +569,7 @@ DEFAULT_CONFIG = { }, # Config schema version - bump this when adding new required fields - "_config_version": 12, + "_config_version": 13, } # ============================================================================= @@ -1119,6 +1125,27 @@ OPTIONAL_ENV_VARS = { "category": "messaging", "advanced": True, }, + "BLUEBUBBLES_SERVER_URL": { + "description": "BlueBubbles server URL for iMessage integration (e.g. http://192.168.1.10:1234)", + "prompt": "BlueBubbles server URL", + "url": "https://bluebubbles.app/", + "password": False, + "category": "messaging", + }, + "BLUEBUBBLES_PASSWORD": { + "description": "BlueBubbles server password (from BlueBubbles Server → Settings → API)", + "prompt": "BlueBubbles server password", + "url": None, + "password": True, + "category": "messaging", + }, + "BLUEBUBBLES_ALLOWED_USERS": { + "description": "Comma-separated iMessage addresses (email or phone) allowed to use the bot", + "prompt": "Allowed iMessage addresses (comma-separated)", + "url": None, + "password": False, + "category": "messaging", + }, "GATEWAY_ALLOW_ALL_USERS": { "description": "Allow all users to interact with messaging bots (true/false). Default: false.", "prompt": "Allow all users (true/false)", @@ -1190,7 +1217,7 @@ OPTIONAL_ENV_VARS = { "category": "setting", }, "SUDO_PASSWORD": { - "description": "Sudo password for terminal commands requiring root access", + "description": "Sudo password for terminal commands requiring root access; set to an explicit empty string to try empty without prompting", "prompt": "Sudo password", "url": None, "password": True, @@ -1674,6 +1701,21 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A ep = providers_dict[key] print(f" → {key}: {ep.get('api', '')}") + # ── Version 12 → 13: clear dead LLM_MODEL / OPENAI_MODEL from .env ── + # These env vars were written by the old setup wizard but nothing reads + # them anymore (config.yaml is the sole source of truth since March 2026). + # Stale entries cause user confusion — see issue report. + if current_ver < 13: + for dead_var in ("LLM_MODEL", "OPENAI_MODEL"): + try: + old_val = get_env_value(dead_var) + if old_val: + save_env_value(dead_var, "") + if not quiet: + print(f" ✓ Cleared {dead_var} from .env (no longer used — config.yaml is source of truth)") + except Exception: + pass + if current_ver < latest_ver and not quiet: print(f"Config version: {current_ver} → {latest_ver}") diff --git a/hermes_cli/dump.py b/hermes_cli/dump.py new file mode 100644 index 0000000000..4ad32ca2c1 --- /dev/null +++ b/hermes_cli/dump.py @@ -0,0 +1,337 @@ +""" +Dump command for hermes CLI. + +Outputs a compact, plain-text summary of the user's Hermes setup +that can be copy-pasted into Discord/GitHub/Telegram for support context. +No ANSI colors, no checkmarks — just data. +""" + +import json +import os +import platform +import subprocess +import sys +from pathlib import Path + +from hermes_cli.config import get_hermes_home, get_env_path, get_project_root, load_config +from hermes_constants import display_hermes_home + + +def _get_git_commit(project_root: Path) -> str: + """Return short git commit hash, or '(unknown)'.""" + try: + result = subprocess.run( + ["git", "rev-parse", "--short=8", "HEAD"], + capture_output=True, text=True, timeout=5, + cwd=str(project_root), + ) + if result.returncode == 0: + return result.stdout.strip() + except Exception: + pass + return "(unknown)" + + +def _key_present(name: str) -> str: + """Return 'set' or 'not set' for an env var.""" + return "set" if os.getenv(name) else "not set" + + +def _redact(value: str) -> str: + """Redact all but first 4 and last 4 chars.""" + if not value: + return "" + if len(value) < 12: + return "***" + return value[:4] + "..." + value[-4:] + + +def _gateway_status() -> str: + """Return a short gateway status string.""" + if sys.platform.startswith("linux"): + try: + from hermes_cli.gateway import get_service_name + svc = get_service_name() + except Exception: + svc = "hermes-gateway" + try: + r = subprocess.run( + ["systemctl", "--user", "is-active", svc], + capture_output=True, text=True, timeout=5, + ) + return "running (systemd)" if r.stdout.strip() == "active" else "stopped" + except Exception: + return "unknown" + elif sys.platform == "darwin": + try: + from hermes_cli.gateway import get_launchd_label + r = subprocess.run( + ["launchctl", "list", get_launchd_label()], + capture_output=True, text=True, timeout=5, + ) + return "loaded (launchd)" if r.returncode == 0 else "not loaded" + except Exception: + return "unknown" + return "N/A" + + +def _count_skills(hermes_home: Path) -> int: + """Count installed skills.""" + skills_dir = hermes_home / "skills" + if not skills_dir.is_dir(): + return 0 + count = 0 + for item in skills_dir.rglob("SKILL.md"): + count += 1 + return count + + +def _count_mcp_servers(config: dict) -> int: + """Count configured MCP servers.""" + mcp = config.get("mcp", {}) + servers = mcp.get("servers", {}) + return len(servers) + + +def _cron_summary(hermes_home: Path) -> str: + """Return cron jobs summary.""" + jobs_file = hermes_home / "cron" / "jobs.json" + if not jobs_file.exists(): + return "0" + try: + with open(jobs_file, encoding="utf-8") as f: + data = json.load(f) + jobs = data.get("jobs", []) + active = sum(1 for j in jobs if j.get("enabled", True)) + return f"{active} active / {len(jobs)} total" + except Exception: + return "(error reading)" + + +def _configured_platforms() -> list[str]: + """Return list of configured messaging platform names.""" + checks = { + "telegram": "TELEGRAM_BOT_TOKEN", + "discord": "DISCORD_BOT_TOKEN", + "slack": "SLACK_BOT_TOKEN", + "whatsapp": "WHATSAPP_ENABLED", + "signal": "SIGNAL_HTTP_URL", + "email": "EMAIL_ADDRESS", + "sms": "TWILIO_ACCOUNT_SID", + "matrix": "MATRIX_HOMESERVER_URL", + "mattermost": "MATTERMOST_URL", + "homeassistant": "HASS_TOKEN", + "dingtalk": "DINGTALK_CLIENT_ID", + "feishu": "FEISHU_APP_ID", + "wecom": "WECOM_BOT_ID", + } + return [name for name, env in checks.items() if os.getenv(env)] + + +def _memory_provider(config: dict) -> str: + """Return the active memory provider name.""" + mem = config.get("memory", {}) + provider = mem.get("provider", "") + return provider if provider else "built-in" + + +def _get_model_and_provider(config: dict) -> tuple[str, str]: + """Extract model and provider from config.""" + model_cfg = config.get("model", "") + if isinstance(model_cfg, dict): + model = model_cfg.get("default") or model_cfg.get("model") or model_cfg.get("name") or "(not set)" + provider = model_cfg.get("provider") or "(auto)" + elif isinstance(model_cfg, str): + model = model_cfg or "(not set)" + provider = "(auto)" + else: + model = "(not set)" + provider = "(auto)" + return model, provider + + +def _config_overrides(config: dict) -> dict[str, str]: + """Find non-default config values worth reporting. + + Returns a flat dict of dotpath -> value for interesting overrides. + """ + from hermes_cli.config import DEFAULT_CONFIG + + overrides = {} + + # Sections with interesting user-facing overrides + interesting_paths = [ + ("agent", "max_turns"), + ("agent", "gateway_timeout"), + ("agent", "tool_use_enforcement"), + ("terminal", "backend"), + ("terminal", "docker_image"), + ("terminal", "persistent_shell"), + ("browser", "allow_private_urls"), + ("compression", "enabled"), + ("compression", "threshold"), + ("display", "streaming"), + ("display", "skin"), + ("display", "show_reasoning"), + ("smart_model_routing", "enabled"), + ("privacy", "redact_pii"), + ("tts", "provider"), + ] + + for section, key in interesting_paths: + default_section = DEFAULT_CONFIG.get(section, {}) + user_section = config.get(section, {}) + if not isinstance(default_section, dict) or not isinstance(user_section, dict): + continue + default_val = default_section.get(key) + user_val = user_section.get(key) + if user_val is not None and user_val != default_val: + overrides[f"{section}.{key}"] = str(user_val) + + # Toolsets (if different from default) + default_toolsets = DEFAULT_CONFIG.get("toolsets", []) + user_toolsets = config.get("toolsets", []) + if user_toolsets != default_toolsets: + overrides["toolsets"] = str(user_toolsets) + + # Fallback providers + fallbacks = config.get("fallback_providers", []) + if fallbacks: + overrides["fallback_providers"] = str(fallbacks) + + return overrides + + +def run_dump(args): + """Output a compact, copy-pasteable setup summary.""" + show_keys = getattr(args, "show_keys", False) + + # Load env from .env file so key checks work + from dotenv import load_dotenv + env_path = get_env_path() + if env_path.exists(): + try: + load_dotenv(env_path, encoding="utf-8") + except UnicodeDecodeError: + load_dotenv(env_path, encoding="latin-1") + # Also try project .env as dev fallback + load_dotenv(get_project_root() / ".env", override=False, encoding="utf-8") + + project_root = get_project_root() + hermes_home = get_hermes_home() + + try: + from hermes_cli import __version__, __release_date__ + except ImportError: + __version__ = "(unknown)" + __release_date__ = "" + + commit = _get_git_commit(project_root) + + try: + config = load_config() + except Exception: + config = {} + + model, provider = _get_model_and_provider(config) + + # Profile + try: + from hermes_cli.profiles import get_active_profile_name + profile = get_active_profile_name() or "(default)" + except Exception: + profile = "(default)" + + # Terminal backend + terminal_cfg = config.get("terminal", {}) + backend = terminal_cfg.get("backend", "local") + + # OpenAI SDK version + try: + import openai + openai_ver = openai.__version__ + except ImportError: + openai_ver = "not installed" + + # OS info + os_info = f"{platform.system()} {platform.release()} {platform.machine()}" + + lines = [] + lines.append("--- hermes dump ---") + ver_str = f"{__version__}" + if __release_date__: + ver_str += f" ({__release_date__})" + ver_str += f" [{commit}]" + lines.append(f"version: {ver_str}") + lines.append(f"os: {os_info}") + lines.append(f"python: {sys.version.split()[0]}") + lines.append(f"openai_sdk: {openai_ver}") + lines.append(f"profile: {profile}") + lines.append(f"hermes_home: {display_hermes_home()}") + lines.append(f"model: {model}") + lines.append(f"provider: {provider}") + lines.append(f"terminal: {backend}") + + # API keys + lines.append("") + lines.append("api_keys:") + api_keys = [ + ("OPENROUTER_API_KEY", "openrouter"), + ("OPENAI_API_KEY", "openai"), + ("ANTHROPIC_API_KEY", "anthropic"), + ("ANTHROPIC_TOKEN", "anthropic_token"), + ("NOUS_API_KEY", "nous"), + ("GLM_API_KEY", "glm/zai"), + ("ZAI_API_KEY", "zai"), + ("KIMI_API_KEY", "kimi"), + ("MINIMAX_API_KEY", "minimax"), + ("DEEPSEEK_API_KEY", "deepseek"), + ("DASHSCOPE_API_KEY", "dashscope"), + ("HF_TOKEN", "huggingface"), + ("AI_GATEWAY_API_KEY", "ai_gateway"), + ("OPENCODE_ZEN_API_KEY", "opencode_zen"), + ("OPENCODE_GO_API_KEY", "opencode_go"), + ("KILOCODE_API_KEY", "kilocode"), + ("FIRECRAWL_API_KEY", "firecrawl"), + ("TAVILY_API_KEY", "tavily"), + ("BROWSERBASE_API_KEY", "browserbase"), + ("FAL_KEY", "fal"), + ("ELEVENLABS_API_KEY", "elevenlabs"), + ("GITHUB_TOKEN", "github"), + ] + + for env_var, label in api_keys: + val = os.getenv(env_var, "") + if show_keys and val: + display = _redact(val) + else: + display = "set" if val else "not set" + lines.append(f" {label:<20} {display}") + + # Features summary + lines.append("") + lines.append("features:") + + toolsets = config.get("toolsets", ["hermes-cli"]) + lines.append(f" toolsets: {', '.join(toolsets) if toolsets else '(default)'}") + lines.append(f" mcp_servers: {_count_mcp_servers(config)}") + lines.append(f" memory_provider: {_memory_provider(config)}") + lines.append(f" gateway: {_gateway_status()}") + + platforms = _configured_platforms() + lines.append(f" platforms: {', '.join(platforms) if platforms else 'none'}") + lines.append(f" cron_jobs: {_cron_summary(hermes_home)}") + lines.append(f" skills: {_count_skills(hermes_home)}") + + # Config overrides (non-default values) + overrides = _config_overrides(config) + if overrides: + lines.append("") + lines.append("config_overrides:") + for key, val in overrides.items(): + lines.append(f" {key}: {val}") + + lines.append("--- end dump ---") + + output = "\n".join(lines) + print(output) diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index 89b01b18c5..82689f8fff 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -1588,6 +1588,34 @@ _PLATFORMS = [ "help": "Chat ID for scheduled results and notifications."}, ], }, + { + "key": "bluebubbles", + "label": "BlueBubbles (iMessage)", + "emoji": "💬", + "token_var": "BLUEBUBBLES_SERVER_URL", + "setup_instructions": [ + "1. Install BlueBubbles on a Mac that will act as your iMessage server:", + " https://bluebubbles.app/", + "2. Complete the BlueBubbles setup wizard — sign in with your Apple ID", + "3. In BlueBubbles Settings → API, note the Server URL and password", + "4. The server URL is typically http://:1234", + "5. Hermes connects via the BlueBubbles REST API and receives", + " incoming messages via a local webhook", + "6. To authorize users, use DM pairing: hermes pairing generate bluebubbles", + " Share the code — the user sends it via iMessage to get approved", + ], + "vars": [ + {"name": "BLUEBUBBLES_SERVER_URL", "prompt": "BlueBubbles server URL (e.g. http://192.168.1.10:1234)", "password": False, + "help": "The URL shown in BlueBubbles Settings → API."}, + {"name": "BLUEBUBBLES_PASSWORD", "prompt": "BlueBubbles server password", "password": True, + "help": "The password shown in BlueBubbles Settings → API."}, + {"name": "BLUEBUBBLES_ALLOWED_USERS", "prompt": "Pre-authorized phone numbers or iMessage IDs (comma-separated, or leave empty for DM pairing)", "password": False, + "is_allowlist": True, + "help": "Optional — pre-authorize specific users. Leave empty to use DM pairing instead (recommended)."}, + {"name": "BLUEBUBBLES_HOME_CHANNEL", "prompt": "Home channel (phone number or iMessage ID for cron/notifications, or empty)", "password": False, + "help": "Phone number or Apple ID to deliver cron results and notifications to."}, + ], + }, ] diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 35dc605f92..c838639ba6 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -1585,7 +1585,11 @@ def _model_flow_custom(config): f"Hermes will still save it." ) if probe.get("suggested_base_url"): - print(f" If this server expects /v1, try base URL: {probe['suggested_base_url']}") + suggested = probe["suggested_base_url"] + if suggested.endswith("/v1"): + print(f" If this server expects /v1 in the path, try base URL: {suggested}") + else: + print(f" If /v1 should not be in the base URL, try: {suggested}") # Select model — use probe results when available, fall back to manual input model_name = "" @@ -2750,6 +2754,12 @@ def cmd_doctor(args): run_doctor(args) +def cmd_dump(args): + """Dump setup summary for support/debugging.""" + from hermes_cli.dump import run_dump + run_dump(args) + + def cmd_config(args): """Configuration management.""" from hermes_cli.config import config_command @@ -4843,6 +4853,22 @@ For more help on a command: help="Attempt to fix issues automatically" ) doctor_parser.set_defaults(func=cmd_doctor) + + # ========================================================================= + # dump command + # ========================================================================= + dump_parser = subparsers.add_parser( + "dump", + help="Dump setup summary for support/debugging", + description="Output a compact, plain-text summary of your Hermes setup " + "that can be copy-pasted into Discord/GitHub for support context" + ) + dump_parser.add_argument( + "--show-keys", + action="store_true", + help="Show redacted API key prefixes (first/last 4 chars) instead of just set/not set" + ) + dump_parser.set_defaults(func=cmd_dump) # ========================================================================= # config command diff --git a/hermes_cli/model_switch.py b/hermes_cli/model_switch.py index 07efbcf4a6..7d120d94f1 100644 --- a/hermes_cli/model_switch.py +++ b/hermes_cli/model_switch.py @@ -537,8 +537,11 @@ def switch_model( ) else: # --- Step c: On aggregator, convert vendor:model to vendor/model --- + # Only convert when there's no slash — a slash means the name + # is already in vendor/model format and the colon is a variant + # tag (:free, :extended, :fast) that must be preserved. colon_pos = raw_input.find(":") - if colon_pos > 0 and is_aggregator(current_provider): + if colon_pos > 0 and "/" not in raw_input and is_aggregator(current_provider): left = raw_input[:colon_pos].strip().lower() right = raw_input[colon_pos + 1:].strip() if left and right: diff --git a/hermes_cli/models.py b/hermes_cli/models.py index ce89bdeac0..b55249a70c 100644 --- a/hermes_cli/models.py +++ b/hermes_cli/models.py @@ -1532,7 +1532,7 @@ def probe_api_models( return { "models": None, - "probed_url": tried[-1] if tried else normalized.rstrip("/") + "/models", + "probed_url": tried[0] if tried else normalized.rstrip("/") + "/models", "resolved_base_url": normalized, "suggested_base_url": alternate_base if alternate_base != normalized else None, "used_fallback": False, diff --git a/hermes_cli/profiles.py b/hermes_cli/profiles.py index 48ecbc4ca4..9be25e1007 100644 --- a/hermes_cli/profiles.py +++ b/hermes_cli/profiles.py @@ -102,7 +102,7 @@ _RESERVED_NAMES = frozenset({ # Hermes subcommands that cannot be used as profile names/aliases _HERMES_SUBCOMMANDS = frozenset({ "chat", "model", "gateway", "setup", "whatsapp", "login", "logout", - "status", "cron", "doctor", "config", "pairing", "skills", "tools", + "status", "cron", "doctor", "dump", "config", "pairing", "skills", "tools", "mcp", "sessions", "insights", "version", "update", "uninstall", "profile", "plugins", "honcho", "acp", }) @@ -1007,7 +1007,7 @@ _hermes_completion() { # Top-level subcommands if [[ "$COMP_CWORD" == 1 ]]; then - local commands="chat model gateway setup status cron doctor config skills tools mcp sessions profile update version" + local commands="chat model gateway setup status cron doctor dump config skills tools mcp sessions profile update version" COMPREPLY=($(compgen -W "$commands" -- "$cur")) fi } @@ -1032,7 +1032,7 @@ _hermes() { _arguments \\ '-p[Profile name]:profile:($profiles)' \\ '--profile[Profile name]:profile:($profiles)' \\ - '1:command:(chat model gateway setup status cron doctor config skills tools mcp sessions profile update version)' \\ + '1:command:(chat model gateway setup status cron doctor dump config skills tools mcp sessions profile update version)' \\ '*::arg:->args' case $words[1] in diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index 43c3b086d9..95c9fa6228 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -2167,6 +2167,71 @@ def _setup_whatsapp(): print_info("or personal self-chat) and pair via QR code.") +def _setup_bluebubbles(): + """Configure BlueBubbles iMessage gateway.""" + print_header("BlueBubbles (iMessage)") + existing = get_env_value("BLUEBUBBLES_SERVER_URL") + if existing: + print_info("BlueBubbles: already configured") + if not prompt_yes_no("Reconfigure BlueBubbles?", False): + return + + print_info("Connects Hermes to iMessage via BlueBubbles — a free, open-source") + print_info("macOS server that bridges iMessage to any device.") + print_info(" Requires a Mac running BlueBubbles Server v1.0.0+") + print_info(" Download: https://bluebubbles.app/") + print() + print_info("In BlueBubbles Server → Settings → API, note your Server URL and Password.") + print() + + server_url = prompt("BlueBubbles server URL (e.g. http://192.168.1.10:1234)") + if not server_url: + print_warning("Server URL is required — skipping BlueBubbles setup") + return + save_env_value("BLUEBUBBLES_SERVER_URL", server_url.rstrip("/")) + + password = prompt("BlueBubbles server password", password=True) + if not password: + print_warning("Password is required — skipping BlueBubbles setup") + return + save_env_value("BLUEBUBBLES_PASSWORD", password) + print_success("BlueBubbles credentials saved") + + print() + print_info("🔒 Security: Restrict who can message your bot") + print_info(" Use iMessage addresses: email (user@icloud.com) or phone (+15551234567)") + print() + allowed_users = prompt("Allowed iMessage addresses (comma-separated, leave empty for open access)") + if allowed_users: + save_env_value("BLUEBUBBLES_ALLOWED_USERS", allowed_users.replace(" ", "")) + print_success("BlueBubbles allowlist configured") + else: + print_info("⚠️ No allowlist set — anyone who can iMessage you can use the bot!") + + print() + print_info("📬 Home Channel: phone or email for cron job delivery and notifications.") + print_info(" You can also set this later with /set-home in your iMessage chat.") + home_channel = prompt("Home channel address (leave empty to set later)") + if home_channel: + save_env_value("BLUEBUBBLES_HOME_CHANNEL", home_channel) + + print() + print_info("Advanced settings (defaults are fine for most setups):") + if prompt_yes_no("Configure webhook listener settings?", False): + webhook_port = prompt("Webhook listener port (default: 8645)") + if webhook_port: + try: + save_env_value("BLUEBUBBLES_WEBHOOK_PORT", str(int(webhook_port))) + print_success(f"Webhook port set to {webhook_port}") + except ValueError: + print_warning("Invalid port number, using default 8645") + + print() + print_info("Requires the BlueBubbles Private API helper for typing indicators,") + print_info("read receipts, and tapback reactions. Basic messaging works without it.") + print_info(" Install: https://docs.bluebubbles.app/helper-bundle/installation") + + def _setup_webhooks(): """Configure webhook integration.""" print_header("Webhooks") @@ -2221,6 +2286,7 @@ _GATEWAY_PLATFORMS = [ ("Matrix", "MATRIX_ACCESS_TOKEN", _setup_matrix), ("Mattermost", "MATTERMOST_TOKEN", _setup_mattermost), ("WhatsApp", "WHATSAPP_ENABLED", _setup_whatsapp), + ("BlueBubbles (iMessage)", "BLUEBUBBLES_SERVER_URL", _setup_bluebubbles), ("Webhooks (GitHub, GitLab, etc.)", "WEBHOOK_ENABLED", _setup_webhooks), ] @@ -2264,6 +2330,7 @@ def setup_gateway(config: dict): or get_env_value("MATRIX_ACCESS_TOKEN") or get_env_value("MATRIX_PASSWORD") or get_env_value("WHATSAPP_ENABLED") + or get_env_value("BLUEBUBBLES_SERVER_URL") or get_env_value("WEBHOOK_ENABLED") ) if any_messaging: @@ -2283,6 +2350,8 @@ def setup_gateway(config: dict): missing_home.append("Discord") if get_env_value("SLACK_BOT_TOKEN") and not get_env_value("SLACK_HOME_CHANNEL"): missing_home.append("Slack") + if get_env_value("BLUEBUBBLES_SERVER_URL") and not get_env_value("BLUEBUBBLES_HOME_CHANNEL"): + missing_home.append("BlueBubbles") if missing_home: print() @@ -2453,6 +2522,8 @@ def _get_section_config_summary(config: dict, section_key: str) -> Optional[str] platforms.append("WhatsApp") if get_env_value("SIGNAL_ACCOUNT"): platforms.append("Signal") + if get_env_value("BLUEBUBBLES_SERVER_URL"): + platforms.append("BlueBubbles") if platforms: return ", ".join(platforms) return None # No platforms configured — section must run diff --git a/hermes_cli/skills_config.py b/hermes_cli/skills_config.py index 7b44014ea5..d7e47ca5f2 100644 --- a/hermes_cli/skills_config.py +++ b/hermes_cli/skills_config.py @@ -23,6 +23,7 @@ PLATFORMS = { "slack": "💼 Slack", "whatsapp": "📱 WhatsApp", "signal": "📡 Signal", + "bluebubbles": "💬 BlueBubbles", "email": "📧 Email", "homeassistant": "🏠 Home Assistant", "mattermost": "💬 Mattermost", diff --git a/hermes_cli/status.py b/hermes_cli/status.py index 6fe8f7df0b..eed89885d2 100644 --- a/hermes_cli/status.py +++ b/hermes_cli/status.py @@ -302,6 +302,7 @@ def show_status(args): "DingTalk": ("DINGTALK_CLIENT_ID", None), "Feishu": ("FEISHU_APP_ID", "FEISHU_HOME_CHANNEL"), "WeCom": ("WECOM_BOT_ID", "WECOM_HOME_CHANNEL"), + "BlueBubbles": ("BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_HOME_CHANNEL"), } for name, (token_var, home_var) in platforms.items(): diff --git a/hermes_cli/tools_config.py b/hermes_cli/tools_config.py index 65525d27d0..9a50a2c5d5 100644 --- a/hermes_cli/tools_config.py +++ b/hermes_cli/tools_config.py @@ -126,6 +126,7 @@ PLATFORMS = { "slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"}, "whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"}, "signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"}, + "bluebubbles": {"label": "💙 BlueBubbles", "default_toolset": "hermes-bluebubbles"}, "homeassistant": {"label": "🏠 Home Assistant", "default_toolset": "hermes-homeassistant"}, "email": {"label": "📧 Email", "default_toolset": "hermes-email"}, "matrix": {"label": "💬 Matrix", "default_toolset": "hermes-matrix"}, diff --git a/hermes_state.py b/hermes_state.py index da632a9e11..a845dbb9f9 100644 --- a/hermes_state.py +++ b/hermes_state.py @@ -1235,10 +1235,10 @@ class SessionDB: self._execute_write(_do) def delete_session(self, session_id: str) -> bool: - """Delete a session, its child sessions, and all their messages. + """Delete a session and all its messages. - Child sessions (subagent runs, compression continuations) are deleted - first to satisfy the ``parent_session_id`` foreign key constraint. + Child sessions are orphaned (parent_session_id set to NULL) rather + than cascade-deleted, so they remain accessible independently. Returns True if the session was found and deleted. """ def _do(conn): @@ -1247,15 +1247,12 @@ class SessionDB: ) if cursor.fetchone()[0] == 0: return False - # Delete child sessions first (FK constraint) - child_ids = [r[0] for r in conn.execute( - "SELECT id FROM sessions WHERE parent_session_id = ?", + # Orphan child sessions so FK constraint is satisfied + conn.execute( + "UPDATE sessions SET parent_session_id = NULL " + "WHERE parent_session_id = ?", (session_id,), - ).fetchall()] - for cid in child_ids: - conn.execute("DELETE FROM messages WHERE session_id = ?", (cid,)) - conn.execute("DELETE FROM sessions WHERE id = ?", (cid,)) - # Delete the session itself + ) conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,)) conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,)) return True @@ -1264,9 +1261,9 @@ class SessionDB: def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int: """Delete sessions older than N days. Returns count of deleted sessions. - Only prunes ended sessions (not active ones). Child sessions whose - parents are being pruned are deleted first to satisfy the - ``parent_session_id`` foreign key constraint. + Only prunes ended sessions (not active ones). Child sessions outside + the prune window are orphaned (parent_session_id set to NULL) rather + than cascade-deleted. """ cutoff = time.time() - (older_than_days * 86400) @@ -1284,17 +1281,16 @@ class SessionDB: ) session_ids = set(row["id"] for row in cursor.fetchall()) - # Delete children first whose parents are in the prune set - # (avoids FK constraint errors) - for sid in list(session_ids): - child_ids = [r[0] for r in conn.execute( - "SELECT id FROM sessions WHERE parent_session_id = ?", - (sid,), - ).fetchall()] - for cid in child_ids: - conn.execute("DELETE FROM messages WHERE session_id = ?", (cid,)) - conn.execute("DELETE FROM sessions WHERE id = ?", (cid,)) - session_ids.discard(cid) # don't double-delete + if not session_ids: + return 0 + + # Orphan any sessions whose parent is about to be deleted + placeholders = ",".join("?" * len(session_ids)) + conn.execute( + f"UPDATE sessions SET parent_session_id = NULL " + f"WHERE parent_session_id IN ({placeholders})", + list(session_ids), + ) for sid in session_ids: conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,)) diff --git a/nix/nixosModules.nix b/nix/nixosModules.nix index c961aa616a..948f7df8c5 100644 --- a/nix/nixosModules.nix +++ b/nix/nixosModules.nix @@ -569,7 +569,7 @@ # ── Activation: link config + auth + documents ──────────────────── { - system.activationScripts."hermes-agent-setup" = lib.stringAfter [ "users" "setupSecrets" ] '' + system.activationScripts."hermes-agent-setup" = lib.stringAfter ([ "users" ] ++ lib.optional (config.system.activationScripts ? setupSecrets) "setupSecrets") '' # Ensure directories exist (activation runs before tmpfiles) mkdir -p ${cfg.stateDir}/.hermes mkdir -p ${cfg.stateDir}/home diff --git a/nix/packages.nix b/nix/packages.nix index 9a65b889d3..eb50d4a17b 100644 --- a/nix/packages.nix +++ b/nix/packages.nix @@ -14,7 +14,7 @@ }; runtimeDeps = with pkgs; [ - nodejs_20 ripgrep git openssh ffmpeg + nodejs_20 ripgrep git openssh ffmpeg tirith ]; runtimePath = pkgs.lib.makeBinPath runtimeDeps; diff --git a/optional-skills/migration/openclaw-migration/scripts/openclaw_to_hermes.py b/optional-skills/migration/openclaw-migration/scripts/openclaw_to_hermes.py index 74e9d7dac3..5e0f76db28 100644 --- a/optional-skills/migration/openclaw-migration/scripts/openclaw_to_hermes.py +++ b/optional-skills/migration/openclaw-migration/scripts/openclaw_to_hermes.py @@ -1803,30 +1803,34 @@ class Migrator: def migrate_cron_jobs(self, config: Optional[Dict[str, Any]] = None) -> None: config = config or self.load_openclaw_config() cron = config.get("cron") or {} - if not cron: - self.record("cron-jobs", None, None, "skipped", "No cron configuration found") - return - - # Archive the full cron config - if self.archive_dir and self.execute: - self.archive_dir.mkdir(parents=True, exist_ok=True) - dest = self.archive_dir / "cron-config.json" - dest.write_text(json.dumps(cron, indent=2, ensure_ascii=False) + "\n", encoding="utf-8") - self.record("cron-jobs", "openclaw.json cron.*", str(dest), "archived", - "Cron config archived. Use 'hermes cron' to recreate jobs manually.") - else: - self.record("cron-jobs", "openclaw.json cron.*", "archive/cron-config.json", - "archived", "Would archive cron config") - - # Also check for cron store files cron_store = self.source_root / "cron" + found_any = False + + # Archive the full cron config when present + if cron: + found_any = True + if self.archive_dir and self.execute: + self.archive_dir.mkdir(parents=True, exist_ok=True) + dest = self.archive_dir / "cron-config.json" + dest.write_text(json.dumps(cron, indent=2, ensure_ascii=False) + "\n", encoding="utf-8") + self.record("cron-jobs", "openclaw.json cron.*", str(dest), "archived", + "Cron config archived. Use 'hermes cron' to recreate jobs manually.") + else: + self.record("cron-jobs", "openclaw.json cron.*", "archive/cron-config.json", + "archived", "Would archive cron config") + + # Also check for cron store files even when config.cron is missing if cron_store.is_dir() and self.archive_dir: + found_any = True dest_cron = self.archive_dir / "cron-store" if self.execute: shutil.copytree(cron_store, dest_cron, dirs_exist_ok=True) self.record("cron-jobs", str(cron_store), str(dest_cron), "archived", "Cron job store archived") + if not found_any: + self.record("cron-jobs", None, None, "skipped", "No cron configuration found") + # ── Hooks ───────────────────────────────────────────────── def migrate_hooks_config(self, config: Optional[Dict[str, Any]] = None) -> None: config = config or self.load_openclaw_config() @@ -2454,6 +2458,15 @@ class Migrator: notes.append(f"- **{item.kind}**: {item.reason}") notes.append("") + has_cron_config_archive = any( + i.kind == "cron-jobs" and i.status == "archived" and i.destination and i.destination.endswith("cron-config.json") + for i in self.items + ) + has_cron_store_archive = any( + i.kind == "cron-jobs" and i.status == "archived" and i.destination and i.destination.endswith("cron-store") + for i in self.items + ) + notes.extend([ "## IMPORTANT: Archive the OpenClaw Directory", "", @@ -2475,7 +2488,14 @@ class Migrator: "- Run `hermes claw cleanup` to archive the OpenClaw directory (prevents state confusion)", "- Run `hermes setup` to configure any remaining settings", "- Run `hermes mcp list` to verify MCP servers were imported correctly", - "- Run `hermes cron` to recreate scheduled tasks (see archive/cron-config.json)", + ]) + + if has_cron_config_archive: + notes.append("- Run `hermes cron` to recreate scheduled tasks (see archive/cron-config.json)") + elif has_cron_store_archive: + notes.append("- Run `hermes cron` to recreate scheduled tasks (see archived cron-store)") + + notes.extend([ "- Run `hermes gateway install` if you need the gateway service", "- Review `~/.hermes/config.yaml` for any adjustments", "", diff --git a/plugins/memory/hindsight/README.md b/plugins/memory/hindsight/README.md index 3a1df59e4d..024a993031 100644 --- a/plugins/memory/hindsight/README.md +++ b/plugins/memory/hindsight/README.md @@ -1,11 +1,12 @@ # Hindsight Memory Provider -Long-term memory with knowledge graph, entity resolution, and multi-strategy retrieval. Supports cloud and local (embedded) modes. +Long-term memory with knowledge graph, entity resolution, and multi-strategy retrieval. Supports cloud, local embedded, and local external modes. ## Requirements - **Cloud:** API key from [ui.hindsight.vectorize.io](https://ui.hindsight.vectorize.io) -- **Local:** API key for a supported LLM provider (OpenAI, Anthropic, Gemini, Groq, MiniMax, or Ollama). Embeddings and reranking run locally — no additional API keys needed. +- **Local Embedded:** API key for a supported LLM provider (OpenAI, Anthropic, Gemini, Groq, OpenRouter, MiniMax, Ollama, or any OpenAI-compatible endpoint). Embeddings and reranking run locally — no additional API keys needed. +- **Local External:** A running Hindsight instance (Docker or self-hosted) reachable over HTTP. ## Setup @@ -21,17 +22,28 @@ hermes config set memory.provider hindsight echo "HINDSIGHT_API_KEY=your-key" >> ~/.hermes/.env ``` -### Cloud Mode +### Cloud Connects to the Hindsight Cloud API. Requires an API key from [ui.hindsight.vectorize.io](https://ui.hindsight.vectorize.io). -### Local Mode +### Local Embedded -Runs an embedded Hindsight server with built-in PostgreSQL. Requires an LLM API key (e.g. Groq, OpenAI, Anthropic) for memory extraction and synthesis. The daemon starts automatically in the background on first use and stops after 5 minutes of inactivity. +Hermes spins up a local Hindsight daemon with built-in PostgreSQL. Requires an LLM API key for memory extraction and synthesis. The daemon starts automatically in the background on first use and stops after 5 minutes of inactivity. + +Supports any OpenAI-compatible LLM endpoint (llama.cpp, vLLM, LM Studio, etc.) — pick `openai_compatible` as the provider and enter the base URL. Daemon startup logs: `~/.hermes/logs/hindsight-embed.log` Daemon runtime logs: `~/.hindsight/profiles/.log` +To open the Hindsight web UI (local embedded mode only): +```bash +hindsight-embed -p hermes ui start +``` + +### Local External + +Points the plugin at an existing Hindsight instance you're already running (Docker, self-hosted, etc.). No daemon management — just a URL and an optional API key. + ## Config Config file: `~/.hermes/hindsight/config.json` @@ -40,40 +52,58 @@ Config file: `~/.hermes/hindsight/config.json` | Key | Default | Description | |-----|---------|-------------| -| `mode` | `cloud` | `cloud` or `local` | -| `api_url` | `https://api.hindsight.vectorize.io` | API URL (cloud mode) | -| `api_url` | `http://localhost:8888` | API URL (local mode, unused — daemon manages its own port) | +| `mode` | `cloud` | `cloud`, `local_embedded`, or `local_external` | +| `api_url` | `https://api.hindsight.vectorize.io` | API URL (cloud and local_external modes) | -### Memory +### Memory Bank | Key | Default | Description | |-----|---------|-------------| | `bank_id` | `hermes` | Memory bank name | -| `budget` | `mid` | Recall thoroughness: `low` / `mid` / `high` | +| `bank_mission` | — | Reflect mission (identity/framing for reflect reasoning). Applied via Banks API. | +| `bank_retain_mission` | — | Retain mission (steers what gets extracted). Applied via Banks API. | + +### Recall + +| Key | Default | Description | +|-----|---------|-------------| +| `recall_budget` | `mid` | Recall thoroughness: `low` / `mid` / `high` | +| `recall_prefetch_method` | `recall` | Auto-recall method: `recall` (raw facts) or `reflect` (LLM synthesis) | +| `recall_max_tokens` | `4096` | Maximum tokens for recall results | +| `recall_max_input_chars` | `800` | Maximum input query length for auto-recall | +| `recall_prompt_preamble` | — | Custom preamble for recalled memories in context | +| `recall_tags` | — | Tags to filter when searching memories | +| `recall_tags_match` | `any` | Tag matching mode: `any` / `all` / `any_strict` / `all_strict` | +| `auto_recall` | `true` | Automatically recall memories before each turn | + +### Retain + +| Key | Default | Description | +|-----|---------|-------------| +| `auto_retain` | `true` | Automatically retain conversation turns | +| `retain_async` | `true` | Process retain asynchronously on the Hindsight server | +| `retain_every_n_turns` | `1` | Retain every N turns (1 = every turn) | +| `retain_context` | `conversation between Hermes Agent and the User` | Context label for retained memories | +| `tags` | — | Tags applied when storing memories | ### Integration | Key | Default | Description | |-----|---------|-------------| | `memory_mode` | `hybrid` | How memories are integrated into the agent | -| `prefetch_method` | `recall` | Method for automatic context injection | **memory_mode:** - `hybrid` — automatic context injection + tools available to the LLM - `context` — automatic injection only, no tools exposed - `tools` — tools only, no automatic injection -**prefetch_method:** -- `recall` — injects raw memory facts (fast) -- `reflect` — injects LLM-synthesized summary (slower, more coherent) - -### Local Mode LLM +### Local Embedded LLM | Key | Default | Description | |-----|---------|-------------| -| `llm_provider` | `openai` | LLM provider: `openai`, `anthropic`, `gemini`, `groq`, `minimax`, `ollama` | -| `llm_model` | per-provider | Model name (e.g. `gpt-4o-mini`, `openai/gpt-oss-120b`) | -| `llm_base_url` | — | LLM Base URL override (e.g. `https://openrouter.ai/api/v1`) | +| `llm_provider` | `openai` | `openai`, `anthropic`, `gemini`, `groq`, `openrouter`, `minimax`, `ollama`, `lmstudio`, `openai_compatible` | +| `llm_model` | per-provider | Model name (e.g. `gpt-4o-mini`, `qwen/qwen3.5-9b`) | +| `llm_base_url` | — | Endpoint URL for `openai_compatible` (e.g. `http://192.168.1.10:8080/v1`) | The LLM API key is stored in `~/.hermes/.env` as `HINDSIGHT_LLM_API_KEY`. @@ -97,4 +127,8 @@ Available in `hybrid` and `tools` memory modes: | `HINDSIGHT_API_URL` | Override API endpoint | | `HINDSIGHT_BANK_ID` | Override bank name | | `HINDSIGHT_BUDGET` | Override recall budget | -| `HINDSIGHT_MODE` | Override mode (`cloud` / `local`) | +| `HINDSIGHT_MODE` | Override mode (`cloud`, `local_embedded`, `local_external`) | + +## Client Version + +Requires `hindsight-client >= 0.4.22`. The plugin auto-upgrades on session start if an older version is detected. diff --git a/plugins/memory/hindsight/__init__.py b/plugins/memory/hindsight/__init__.py index c87497745e..c39679b73c 100644 --- a/plugins/memory/hindsight/__init__.py +++ b/plugins/memory/hindsight/__init__.py @@ -28,21 +28,25 @@ from hermes_constants import get_hermes_home from typing import Any, Dict, List from agent.memory_provider import MemoryProvider +from hermes_constants import get_hermes_home from tools.registry import tool_error logger = logging.getLogger(__name__) _DEFAULT_API_URL = "https://api.hindsight.vectorize.io" _DEFAULT_LOCAL_URL = "http://localhost:8888" +_MIN_CLIENT_VERSION = "0.4.22" _VALID_BUDGETS = {"low", "mid", "high"} _PROVIDER_DEFAULT_MODELS = { "openai": "gpt-4o-mini", "anthropic": "claude-haiku-4-5", "gemini": "gemini-2.5-flash", "groq": "openai/gpt-oss-120b", + "openrouter": "qwen/qwen3.5-9b", "minimax": "MiniMax-M2.7", "ollama": "gemma3:12b", "lmstudio": "local-model", + "openai_compatible": "your-model-name", } @@ -188,6 +192,7 @@ class HindsightMemoryProvider(MemoryProvider): self._bank_id = "hermes" self._budget = "mid" self._mode = "cloud" + self._llm_base_url = "" self._memory_mode = "hybrid" # "context", "tools", or "hybrid" self._prefetch_method = "recall" # "recall" or "reflect" self._client = None @@ -195,6 +200,31 @@ class HindsightMemoryProvider(MemoryProvider): self._prefetch_lock = threading.Lock() self._prefetch_thread = None self._sync_thread = None + self._session_id = "" + + # Tags + self._tags: list[str] | None = None + self._recall_tags: list[str] | None = None + self._recall_tags_match = "any" + + # Retain controls + self._auto_retain = True + self._retain_every_n_turns = 1 + self._retain_context = "conversation between Hermes Agent and the User" + self._turn_counter = 0 + self._session_turns: list[str] = [] # accumulates ALL turns for the session + + # Recall controls + self._auto_recall = True + self._recall_max_tokens = 4096 + self._recall_types: list[str] | None = None + self._recall_prompt_preamble = "" + self._recall_max_input_chars = 800 + + # Bank + self._bank_mission = "" + self._bank_retain_mission: str | None = None + self._retain_async = True @property def name(self) -> str: @@ -204,7 +234,7 @@ class HindsightMemoryProvider(MemoryProvider): try: cfg = _load_config() mode = cfg.get("mode", "cloud") - if mode == "local": + if mode in ("local", "local_embedded", "local_external"): return True has_key = bool(cfg.get("apiKey") or os.environ.get("HINDSIGHT_API_KEY", "")) has_url = bool(cfg.get("api_url") or os.environ.get("HINDSIGHT_API_URL", "")) @@ -228,73 +258,306 @@ class HindsightMemoryProvider(MemoryProvider): existing.update(values) config_path.write_text(json.dumps(existing, indent=2)) + def post_setup(self, hermes_home: str, config: dict) -> None: + """Custom setup wizard — installs only the deps needed for the selected mode.""" + import getpass + import subprocess + import shutil + import sys + from pathlib import Path + + from hermes_cli.config import save_config + + from hermes_cli.memory_setup import _curses_select + + print("\n Configuring Hindsight memory:\n") + + # Step 1: Mode selection + mode_items = [ + ("Cloud", "Hindsight Cloud API (lightweight, just needs an API key)"), + ("Local Embedded", "Run Hindsight locally (downloads ~200MB, needs LLM key)"), + ("Local External", "Connect to an existing Hindsight instance"), + ] + mode_idx = _curses_select(" Select mode", mode_items, default=0) + mode = ["cloud", "local_embedded", "local_external"][mode_idx] + + provider_config: dict = {"mode": mode} + env_writes: dict = {} + + # Step 2: Install/upgrade deps for selected mode + _MIN_CLIENT_VERSION = "0.4.22" + cloud_dep = f"hindsight-client>={_MIN_CLIENT_VERSION}" + local_dep = "hindsight-all" + if mode == "local_embedded": + deps_to_install = [local_dep] + elif mode == "local_external": + deps_to_install = [cloud_dep] + else: + deps_to_install = [cloud_dep] + + print(f"\n Checking dependencies...") + uv_path = shutil.which("uv") + if not uv_path: + print(" ⚠ uv not found — install it: curl -LsSf https://astral.sh/uv/install.sh | sh") + print(f" Then run manually: uv pip install --python {sys.executable} {' '.join(deps_to_install)}") + else: + try: + subprocess.run( + [uv_path, "pip", "install", "--python", sys.executable, "--quiet", "--upgrade"] + deps_to_install, + check=True, timeout=120, capture_output=True, + ) + print(f" ✓ Dependencies up to date") + except Exception as e: + print(f" ⚠ Install failed: {e}") + print(f" Run manually: uv pip install --python {sys.executable} {' '.join(deps_to_install)}") + + # Step 3: Mode-specific config + if mode == "cloud": + print(f"\n Get your API key at https://ui.hindsight.vectorize.io\n") + existing_key = os.environ.get("HINDSIGHT_API_KEY", "") + if existing_key: + masked = f"...{existing_key[-4:]}" if len(existing_key) > 4 else "set" + sys.stdout.write(f" API key (current: {masked}, blank to keep): ") + sys.stdout.flush() + api_key = getpass.getpass(prompt="") if sys.stdin.isatty() else sys.stdin.readline().strip() + else: + sys.stdout.write(" API key: ") + sys.stdout.flush() + api_key = getpass.getpass(prompt="") if sys.stdin.isatty() else sys.stdin.readline().strip() + if api_key: + env_writes["HINDSIGHT_API_KEY"] = api_key + + val = input(f" API URL [{_DEFAULT_API_URL}]: ").strip() + if val: + provider_config["api_url"] = val + + elif mode == "local_external": + val = input(f" Hindsight API URL [{_DEFAULT_LOCAL_URL}]: ").strip() + provider_config["api_url"] = val or _DEFAULT_LOCAL_URL + + sys.stdout.write(" API key (optional, blank to skip): ") + sys.stdout.flush() + api_key = getpass.getpass(prompt="") if sys.stdin.isatty() else sys.stdin.readline().strip() + if api_key: + env_writes["HINDSIGHT_API_KEY"] = api_key + + else: # local_embedded + providers_list = list(_PROVIDER_DEFAULT_MODELS.keys()) + llm_items = [ + (p, f"default model: {_PROVIDER_DEFAULT_MODELS[p]}") + for p in providers_list + ] + llm_idx = _curses_select(" Select LLM provider", llm_items, default=0) + llm_provider = providers_list[llm_idx] + + provider_config["llm_provider"] = llm_provider + + if llm_provider == "openai_compatible": + val = input(" LLM endpoint URL (e.g. http://192.168.1.10:8080/v1): ").strip() + if val: + provider_config["llm_base_url"] = val + elif llm_provider == "openrouter": + provider_config["llm_base_url"] = "https://openrouter.ai/api/v1" + + default_model = _PROVIDER_DEFAULT_MODELS.get(llm_provider, "gpt-4o-mini") + val = input(f" LLM model [{default_model}]: ").strip() + provider_config["llm_model"] = val or default_model + + sys.stdout.write(" LLM API key: ") + sys.stdout.flush() + llm_key = getpass.getpass(prompt="") if sys.stdin.isatty() else sys.stdin.readline().strip() + if llm_key: + env_writes["HINDSIGHT_LLM_API_KEY"] = llm_key + + # Step 4: Save everything + provider_config["bank_id"] = "hermes" + provider_config["recall_budget"] = "mid" + bank_id = "hermes" + config["memory"]["provider"] = "hindsight" + save_config(config) + + self.save_config(provider_config, hermes_home) + + if env_writes: + env_path = Path(hermes_home) / ".env" + env_path.parent.mkdir(parents=True, exist_ok=True) + existing_lines = [] + if env_path.exists(): + existing_lines = env_path.read_text().splitlines() + updated_keys = set() + new_lines = [] + for line in existing_lines: + key_match = line.split("=", 1)[0].strip() if "=" in line and not line.startswith("#") else None + if key_match and key_match in env_writes: + new_lines.append(f"{key_match}={env_writes[key_match]}") + updated_keys.add(key_match) + else: + new_lines.append(line) + for k, v in env_writes.items(): + if k not in updated_keys: + new_lines.append(f"{k}={v}") + env_path.write_text("\n".join(new_lines) + "\n") + + print(f"\n ✓ Hindsight memory configured ({mode} mode)") + if env_writes: + print(f" API keys saved to .env") + print(f"\n Start a new session to activate.\n") + def get_config_schema(self): return [ - {"key": "mode", "description": "Cloud API or local embedded mode", "default": "cloud", "choices": ["cloud", "local"]}, - {"key": "api_url", "description": "Hindsight API URL", "default": _DEFAULT_API_URL, "when": {"mode": "cloud"}}, + {"key": "mode", "description": "Connection mode", "default": "cloud", "choices": ["cloud", "local_embedded", "local_external"]}, + # Cloud mode + {"key": "api_url", "description": "Hindsight Cloud API URL", "default": _DEFAULT_API_URL, "when": {"mode": "cloud"}}, {"key": "api_key", "description": "Hindsight Cloud API key", "secret": True, "env_var": "HINDSIGHT_API_KEY", "url": "https://ui.hindsight.vectorize.io", "when": {"mode": "cloud"}}, - {"key": "llm_provider", "description": "LLM provider for local mode", "default": "openai", "choices": ["openai", "anthropic", "gemini", "groq", "minimax", "ollama"], "when": {"mode": "local"}}, - {"key": "llm_api_key", "description": "LLM API key for local Hindsight", "secret": True, "env_var": "HINDSIGHT_LLM_API_KEY", "when": {"mode": "local"}}, - {"key": "llm_base_url", "description": "LLM Base URL (e.g. for OpenRouter)", "default": "", "env_var": "HINDSIGHT_API_LLM_BASE_URL", "when": {"mode": "local"}}, - {"key": "llm_model", "description": "LLM model for local mode", "default": "gpt-4o-mini", "default_from": {"field": "llm_provider", "map": _PROVIDER_DEFAULT_MODELS}, "when": {"mode": "local"}}, + # Local external mode + {"key": "api_url", "description": "Hindsight API URL", "default": _DEFAULT_LOCAL_URL, "when": {"mode": "local_external"}}, + {"key": "api_key", "description": "API key (optional)", "secret": True, "env_var": "HINDSIGHT_API_KEY", "when": {"mode": "local_external"}}, + # Local embedded mode + {"key": "llm_provider", "description": "LLM provider", "default": "openai", "choices": ["openai", "anthropic", "gemini", "groq", "openrouter", "minimax", "ollama", "lmstudio", "openai_compatible"], "when": {"mode": "local_embedded"}}, + {"key": "llm_base_url", "description": "Endpoint URL (e.g. http://192.168.1.10:8080/v1)", "default": "", "when": {"mode": "local_embedded", "llm_provider": "openai_compatible"}}, + {"key": "llm_api_key", "description": "LLM API key (optional for openai_compatible)", "secret": True, "env_var": "HINDSIGHT_LLM_API_KEY", "when": {"mode": "local_embedded"}}, + {"key": "llm_model", "description": "LLM model", "default": "gpt-4o-mini", "default_from": {"field": "llm_provider", "map": _PROVIDER_DEFAULT_MODELS}, "when": {"mode": "local_embedded"}}, {"key": "bank_id", "description": "Memory bank name", "default": "hermes"}, - {"key": "budget", "description": "Recall thoroughness", "default": "mid", "choices": ["low", "mid", "high"]}, + {"key": "bank_mission", "description": "Mission/purpose description for the memory bank"}, + {"key": "bank_retain_mission", "description": "Custom extraction prompt for memory retention"}, + {"key": "recall_budget", "description": "Recall thoroughness", "default": "mid", "choices": ["low", "mid", "high"]}, {"key": "memory_mode", "description": "Memory integration mode", "default": "hybrid", "choices": ["hybrid", "context", "tools"]}, - {"key": "prefetch_method", "description": "Auto-recall method", "default": "recall", "choices": ["recall", "reflect"]}, + {"key": "recall_prefetch_method", "description": "Auto-recall method", "default": "recall", "choices": ["recall", "reflect"]}, + {"key": "tags", "description": "Tags applied when storing memories (comma-separated)", "default": ""}, + {"key": "recall_tags", "description": "Tags to filter when searching memories (comma-separated)", "default": ""}, + {"key": "recall_tags_match", "description": "Tag matching mode for recall", "default": "any", "choices": ["any", "all", "any_strict", "all_strict"]}, + {"key": "auto_recall", "description": "Automatically recall memories before each turn", "default": True}, + {"key": "auto_retain", "description": "Automatically retain conversation turns", "default": True}, + {"key": "retain_every_n_turns", "description": "Retain every N turns (1 = every turn)", "default": 1}, + {"key": "retain_async","description": "Process retain asynchronously on the Hindsight server", "default": True}, + {"key": "retain_context", "description": "Context label for retained memories", "default": "conversation between Hermes Agent and the User"}, + {"key": "recall_max_tokens", "description": "Maximum tokens for recall results", "default": 4096}, + {"key": "recall_max_input_chars", "description": "Maximum input query length for auto-recall", "default": 800}, + {"key": "recall_prompt_preamble", "description": "Custom preamble for recalled memories in context"}, ] def _get_client(self): """Return the cached Hindsight client (created once, reused).""" if self._client is None: - if self._mode == "local": + if self._mode == "local_embedded": from hindsight import HindsightEmbedded - # Disable __del__ on the class to prevent "attached to a - # different loop" errors during GC — we handle cleanup in - # shutdown() instead. HindsightEmbedded.__del__ = lambda self: None + llm_provider = self._config.get("llm_provider", "") + if llm_provider in ("openai_compatible", "openrouter"): + llm_provider = "openai" + logger.debug("Creating HindsightEmbedded client (profile=%s, provider=%s)", + self._config.get("profile", "hermes"), llm_provider) kwargs = dict( profile=self._config.get("profile", "hermes"), - llm_provider=self._config.get("llm_provider", ""), - llm_api_key=self._config.get("llm_api_key") or os.environ.get("HINDSIGHT_LLM_API_KEY", ""), + llm_provider=llm_provider, + llm_api_key=self._config.get("llmApiKey") or self._config.get("llm_api_key") or os.environ.get("HINDSIGHT_LLM_API_KEY", ""), llm_model=self._config.get("llm_model", ""), ) - base_url = self._config.get("llm_base_url") or os.environ.get("HINDSIGHT_API_LLM_BASE_URL", "") - if base_url: - kwargs["llm_base_url"] = base_url + if self._llm_base_url: + kwargs["llm_base_url"] = self._llm_base_url self._client = HindsightEmbedded(**kwargs) else: from hindsight_client import Hindsight kwargs = {"base_url": self._api_url, "timeout": 30.0} if self._api_key: kwargs["api_key"] = self._api_key + logger.debug("Creating Hindsight cloud client (url=%s, has_key=%s)", + self._api_url, bool(self._api_key)) self._client = Hindsight(**kwargs) return self._client def initialize(self, session_id: str, **kwargs) -> None: + self._session_id = session_id + + # Check client version and auto-upgrade if needed + try: + from importlib.metadata import version as pkg_version + from packaging.version import Version + installed = pkg_version("hindsight-client") + if Version(installed) < Version(_MIN_CLIENT_VERSION): + logger.warning("hindsight-client %s is outdated (need >=%s), attempting upgrade...", + installed, _MIN_CLIENT_VERSION) + import shutil, subprocess, sys + uv_path = shutil.which("uv") + if uv_path: + try: + subprocess.run( + [uv_path, "pip", "install", "--python", sys.executable, + "--quiet", "--upgrade", f"hindsight-client>={_MIN_CLIENT_VERSION}"], + check=True, timeout=120, capture_output=True, + ) + logger.info("hindsight-client upgraded to >=%s", _MIN_CLIENT_VERSION) + except Exception as e: + logger.warning("Auto-upgrade failed: %s. Run: uv pip install 'hindsight-client>=%s'", + e, _MIN_CLIENT_VERSION) + else: + logger.warning("uv not found. Run: pip install 'hindsight-client>=%s'", _MIN_CLIENT_VERSION) + except Exception: + pass # packaging not available or other issue — proceed anyway + self._config = _load_config() self._mode = self._config.get("mode", "cloud") - self._api_key = self._config.get("apiKey") or os.environ.get("HINDSIGHT_API_KEY", "") - default_url = _DEFAULT_LOCAL_URL if self._mode == "local" else _DEFAULT_API_URL + # "local" is a legacy alias for "local_embedded" + if self._mode == "local": + self._mode = "local_embedded" + self._api_key = self._config.get("apiKey") or self._config.get("api_key") or os.environ.get("HINDSIGHT_API_KEY", "") + default_url = _DEFAULT_LOCAL_URL if self._mode in ("local_embedded", "local_external") else _DEFAULT_API_URL self._api_url = self._config.get("api_url") or os.environ.get("HINDSIGHT_API_URL", default_url) + self._llm_base_url = self._config.get("llm_base_url", "") banks = self._config.get("banks", {}).get("hermes", {}) self._bank_id = self._config.get("bank_id") or banks.get("bankId", "hermes") - budget = self._config.get("budget") or banks.get("budget", "mid") + budget = self._config.get("recall_budget") or self._config.get("budget") or banks.get("budget", "mid") self._budget = budget if budget in _VALID_BUDGETS else "mid" memory_mode = self._config.get("memory_mode", "hybrid") self._memory_mode = memory_mode if memory_mode in ("context", "tools", "hybrid") else "hybrid" - prefetch_method = self._config.get("prefetch_method", "recall") + prefetch_method = self._config.get("recall_prefetch_method", "recall") self._prefetch_method = prefetch_method if prefetch_method in ("recall", "reflect") else "recall" - logger.info("Hindsight initialized: mode=%s, api_url=%s, bank=%s, budget=%s, memory_mode=%s, prefetch_method=%s", - self._mode, self._api_url, self._bank_id, self._budget, self._memory_mode, self._prefetch_method) + # Bank options + self._bank_mission = self._config.get("bank_mission", "") + self._bank_retain_mission = self._config.get("bank_retain_mission") or None + + # Tags + self._tags = self._config.get("tags") or None + self._recall_tags = self._config.get("recall_tags") or None + self._recall_tags_match = self._config.get("recall_tags_match", "any") + + # Retain controls + self._auto_retain = self._config.get("auto_retain", True) + self._retain_every_n_turns = max(1, int(self._config.get("retain_every_n_turns", 1))) + self._retain_context = self._config.get("retain_context", "conversation between Hermes Agent and the User") + + # Recall controls + self._auto_recall = self._config.get("auto_recall", True) + self._recall_max_tokens = int(self._config.get("recall_max_tokens", 4096)) + self._recall_types = self._config.get("recall_types") or None + self._recall_prompt_preamble = self._config.get("recall_prompt_preamble", "") + self._recall_max_input_chars = int(self._config.get("recall_max_input_chars", 800)) + self._retain_async = self._config.get("retain_async", True) + + _client_version = "unknown" + try: + from importlib.metadata import version as pkg_version + _client_version = pkg_version("hindsight-client") + except Exception: + pass + logger.info("Hindsight initialized: mode=%s, api_url=%s, bank=%s, budget=%s, memory_mode=%s, prefetch_method=%s, client=%s", + self._mode, self._api_url, self._bank_id, self._budget, self._memory_mode, self._prefetch_method, _client_version) + logger.debug("Hindsight config: auto_retain=%s, auto_recall=%s, retain_every_n=%d, " + "retain_async=%s, retain_context=%s, " + "recall_max_tokens=%d, recall_max_input_chars=%d, tags=%s, recall_tags=%s", + self._auto_retain, self._auto_recall, self._retain_every_n_turns, + self._retain_async, self._retain_context, + self._recall_max_tokens, self._recall_max_input_chars, + self._tags, self._recall_tags) # For local mode, start the embedded daemon in the background so it # doesn't block the chat. Redirect stdout/stderr to a log file to # prevent rich startup output from spamming the terminal. - if self._mode == "local": + if self._mode == "local_embedded": def _start_daemon(): import traceback log_dir = get_hermes_home() / "logs" @@ -320,6 +583,8 @@ class HindsightMemoryProvider(MemoryProvider): current_provider = self._config.get("llm_provider", "") current_model = self._config.get("llm_model", "") current_base_url = self._config.get("llm_base_url") or os.environ.get("HINDSIGHT_API_LLM_BASE_URL", "") + # Map openai_compatible/openrouter → openai for the daemon (OpenAI wire format) + daemon_provider = "openai" if current_provider in ("openai_compatible", "openrouter") else current_provider # Read saved profile config saved = {} @@ -330,7 +595,7 @@ class HindsightMemoryProvider(MemoryProvider): saved[k.strip()] = v.strip() config_changed = ( - saved.get("HINDSIGHT_API_LLM_PROVIDER") != current_provider or + saved.get("HINDSIGHT_API_LLM_PROVIDER") != daemon_provider or saved.get("HINDSIGHT_API_LLM_MODEL") != current_model or saved.get("HINDSIGHT_API_LLM_API_KEY") != current_key or saved.get("HINDSIGHT_API_LLM_BASE_URL", "") != current_base_url @@ -340,7 +605,7 @@ class HindsightMemoryProvider(MemoryProvider): # Write updated profile .env profile_env.parent.mkdir(parents=True, exist_ok=True) env_lines = ( - f"HINDSIGHT_API_LLM_PROVIDER={current_provider}\n" + f"HINDSIGHT_API_LLM_PROVIDER={daemon_provider}\n" f"HINDSIGHT_API_LLM_API_KEY={current_key}\n" f"HINDSIGHT_API_LLM_MODEL={current_model}\n" f"HINDSIGHT_API_LOG_LEVEL=info\n" @@ -388,47 +653,118 @@ class HindsightMemoryProvider(MemoryProvider): def prefetch(self, query: str, *, session_id: str = "") -> str: if self._prefetch_thread and self._prefetch_thread.is_alive(): + logger.debug("Prefetch: waiting for background thread to complete") self._prefetch_thread.join(timeout=3.0) with self._prefetch_lock: result = self._prefetch_result self._prefetch_result = "" if not result: + logger.debug("Prefetch: no results available") return "" - return f"## Hindsight Memory\n{result}" + logger.debug("Prefetch: returning %d chars of context", len(result)) + header = self._recall_prompt_preamble or ( + "# Hindsight Memory (persistent cross-session context)\n" + "Use this to answer questions about the user and prior sessions. " + "Do not call tools to look up information that is already present here." + ) + return f"{header}\n\n{result}" def queue_prefetch(self, query: str, *, session_id: str = "") -> None: if self._memory_mode == "tools": + logger.debug("Prefetch: skipped (tools-only mode)") return + if not self._auto_recall: + logger.debug("Prefetch: skipped (auto_recall disabled)") + return + # Truncate query to max chars + if self._recall_max_input_chars and len(query) > self._recall_max_input_chars: + query = query[:self._recall_max_input_chars] + def _run(): try: client = self._get_client() if self._prefetch_method == "reflect": + logger.debug("Prefetch: calling reflect (bank=%s, query_len=%d)", self._bank_id, len(query)) resp = _run_sync(client.areflect(bank_id=self._bank_id, query=query, budget=self._budget)) text = resp.text or "" else: - resp = _run_sync(client.arecall(bank_id=self._bank_id, query=query, budget=self._budget)) - text = "\n".join(r.text for r in resp.results if r.text) if resp.results else "" + recall_kwargs: dict = { + "bank_id": self._bank_id, "query": query, + "budget": self._budget, "max_tokens": self._recall_max_tokens, + } + if self._recall_tags: + recall_kwargs["tags"] = self._recall_tags + recall_kwargs["tags_match"] = self._recall_tags_match + if self._recall_types: + recall_kwargs["types"] = self._recall_types + logger.debug("Prefetch: calling recall (bank=%s, query_len=%d, budget=%s)", + self._bank_id, len(query), self._budget) + resp = _run_sync(client.arecall(**recall_kwargs)) + num_results = len(resp.results) if resp.results else 0 + logger.debug("Prefetch: recall returned %d results", num_results) + text = "\n".join(f"- {r.text}" for r in resp.results if r.text) if resp.results else "" if text: with self._prefetch_lock: self._prefetch_result = text except Exception as e: - logger.debug("Hindsight prefetch failed: %s", e) + logger.debug("Hindsight prefetch failed: %s", e, exc_info=True) self._prefetch_thread = threading.Thread(target=_run, daemon=True, name="hindsight-prefetch") self._prefetch_thread.start() def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None: - """Retain conversation turn in background (non-blocking).""" - combined = f"User: {user_content}\nAssistant: {assistant_content}" + """Retain conversation turn in background (non-blocking). + + Respects retain_every_n_turns for batching. + """ + if not self._auto_retain: + logger.debug("sync_turn: skipped (auto_retain disabled)") + return + + from datetime import datetime, timezone + now = datetime.now(timezone.utc).isoformat() + + messages = [ + {"role": "user", "content": user_content, "timestamp": now}, + {"role": "assistant", "content": assistant_content, "timestamp": now}, + ] + + turn = json.dumps(messages) + self._session_turns.append(turn) + self._turn_counter += 1 + + # Only retain every N turns + if self._turn_counter % self._retain_every_n_turns != 0: + logger.debug("sync_turn: buffered turn %d (will retain at turn %d)", + self._turn_counter, self._turn_counter + (self._retain_every_n_turns - self._turn_counter % self._retain_every_n_turns)) + return + + logger.debug("sync_turn: retaining %d turns, total session content %d chars", + len(self._session_turns), sum(len(t) for t in self._session_turns)) + # Send the ENTIRE session as a single JSON array (document_id deduplicates). + # Each element in _session_turns is a JSON string of that turn's messages. + content = "[" + ",".join(self._session_turns) + "]" def _sync(): try: client = self._get_client() - _run_sync(client.aretain( - bank_id=self._bank_id, content=combined, context="conversation" + item: dict = { + "content": content, + "context": self._retain_context, + } + if self._tags: + item["tags"] = self._tags + logger.debug("Hindsight retain: bank=%s, doc=%s, async=%s, content_len=%d, num_turns=%d", + self._bank_id, self._session_id, self._retain_async, len(content), len(self._session_turns)) + _run_sync(client.aretain_batch( + bank_id=self._bank_id, + items=[item], + document_id=self._session_id, + retain_async=self._retain_async, )) + logger.debug("Hindsight retain succeeded") except Exception as e: - logger.warning("Hindsight sync failed: %s", e) + logger.warning("Hindsight sync failed: %s", e, exc_info=True) if self._sync_thread and self._sync_thread.is_alive(): self._sync_thread.join(timeout=5.0) @@ -453,12 +789,18 @@ class HindsightMemoryProvider(MemoryProvider): return tool_error("Missing required parameter: content") context = args.get("context") try: - _run_sync(client.aretain( - bank_id=self._bank_id, content=content, context=context - )) + retain_kwargs: dict = { + "bank_id": self._bank_id, "content": content, "context": context, + } + if self._tags: + retain_kwargs["tags"] = self._tags + logger.debug("Tool hindsight_retain: bank=%s, content_len=%d, context=%s", + self._bank_id, len(content), context) + _run_sync(client.aretain(**retain_kwargs)) + logger.debug("Tool hindsight_retain: success") return json.dumps({"result": "Memory stored successfully."}) except Exception as e: - logger.warning("hindsight_retain failed: %s", e) + logger.warning("hindsight_retain failed: %s", e, exc_info=True) return tool_error(f"Failed to store memory: {e}") elif tool_name == "hindsight_recall": @@ -466,15 +808,26 @@ class HindsightMemoryProvider(MemoryProvider): if not query: return tool_error("Missing required parameter: query") try: - resp = _run_sync(client.arecall( - bank_id=self._bank_id, query=query, budget=self._budget - )) + recall_kwargs: dict = { + "bank_id": self._bank_id, "query": query, "budget": self._budget, + "max_tokens": self._recall_max_tokens, + } + if self._recall_tags: + recall_kwargs["tags"] = self._recall_tags + recall_kwargs["tags_match"] = self._recall_tags_match + if self._recall_types: + recall_kwargs["types"] = self._recall_types + logger.debug("Tool hindsight_recall: bank=%s, query_len=%d, budget=%s", + self._bank_id, len(query), self._budget) + resp = _run_sync(client.arecall(**recall_kwargs)) + num_results = len(resp.results) if resp.results else 0 + logger.debug("Tool hindsight_recall: %d results", num_results) if not resp.results: return json.dumps({"result": "No relevant memories found."}) lines = [f"{i}. {r.text}" for i, r in enumerate(resp.results, 1)] return json.dumps({"result": "\n".join(lines)}) except Exception as e: - logger.warning("hindsight_recall failed: %s", e) + logger.warning("hindsight_recall failed: %s", e, exc_info=True) return tool_error(f"Failed to search memory: {e}") elif tool_name == "hindsight_reflect": @@ -482,24 +835,28 @@ class HindsightMemoryProvider(MemoryProvider): if not query: return tool_error("Missing required parameter: query") try: + logger.debug("Tool hindsight_reflect: bank=%s, query_len=%d, budget=%s", + self._bank_id, len(query), self._budget) resp = _run_sync(client.areflect( bank_id=self._bank_id, query=query, budget=self._budget )) + logger.debug("Tool hindsight_reflect: response_len=%d", len(resp.text or "")) return json.dumps({"result": resp.text or "No relevant memories found."}) except Exception as e: - logger.warning("hindsight_reflect failed: %s", e) + logger.warning("hindsight_reflect failed: %s", e, exc_info=True) return tool_error(f"Failed to reflect: {e}") return tool_error(f"Unknown tool: {tool_name}") def shutdown(self) -> None: + logger.debug("Hindsight shutdown: waiting for background threads") global _loop, _loop_thread for t in (self._prefetch_thread, self._sync_thread): if t and t.is_alive(): t.join(timeout=5.0) if self._client is not None: try: - if self._mode == "local": + if self._mode == "local_embedded": # Use the public close() API. The RuntimeError from # aiohttp's "attached to a different loop" is expected # and harmless — the daemon keeps running independently. diff --git a/plugins/memory/hindsight/plugin.yaml b/plugins/memory/hindsight/plugin.yaml index 7985189920..b12c09142b 100644 --- a/plugins/memory/hindsight/plugin.yaml +++ b/plugins/memory/hindsight/plugin.yaml @@ -2,9 +2,7 @@ name: hindsight version: 1.0.0 description: "Hindsight — long-term memory with knowledge graph, entity resolution, and multi-strategy retrieval." pip_dependencies: - - hindsight-client - - hindsight-all -requires_env: - - HINDSIGHT_API_KEY + - "hindsight-client>=0.4.22" +requires_env: [] hooks: - on_session_end diff --git a/run_agent.py b/run_agent.py index fd1337cbb7..fc04706838 100644 --- a/run_agent.py +++ b/run_agent.py @@ -66,7 +66,7 @@ from model_tools import ( handle_function_call, check_toolset_requirements, ) -from tools.terminal_tool import cleanup_vm, get_active_env +from tools.terminal_tool import cleanup_vm, get_active_env, is_persistent_env from tools.tool_result_storage import maybe_persist_tool_result, enforce_turn_budget from tools.interrupt import set_interrupt as _set_interrupt from tools.browser_tool import cleanup_browser @@ -77,6 +77,7 @@ from hermes_constants import OPENROUTER_BASE_URL # Agent internals extracted to agent/ package for modularity from agent.memory_manager import build_memory_context_block from agent.retry_utils import jittered_backoff +from agent.error_classifier import classify_api_error, FailoverReason from agent.prompt_builder import ( DEFAULT_AGENT_IDENTITY, PLATFORM_HINTS, MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, SKILLS_GUIDANCE, @@ -442,6 +443,13 @@ class AIAgent: for AI models that support function calling. """ + # ── Class-level context pressure dedup (survives across instances) ── + # The gateway creates a new AIAgent per message, so instance-level flags + # reset every time. This dict tracks {session_id: (warn_level, timestamp)} + # to suppress duplicate warnings within a cooldown window. + _context_pressure_last_warned: dict = {} + _CONTEXT_PRESSURE_COOLDOWN = 300 # seconds between re-warning same session + @property def base_url(self) -> str: return self._base_url @@ -673,7 +681,8 @@ class AIAgent: # Context pressure warnings: notify the USER (not the LLM) as context # fills up. Purely informational — displayed in CLI output and sent via # status_callback for gateway platforms. Does NOT inject into messages. - self._context_pressure_warned = False + # Tiered: fires at 85% and again at 95% of compaction threshold. + self._context_pressure_warned_at = 0.0 # highest tier already shown # Activity tracking — updated on each API call, tool execution, and # stream chunk. Used by the gateway timeout handler to report what the @@ -684,6 +693,10 @@ class AIAgent: self._current_tool: str | None = None self._api_call_count: int = 0 + # Rate limit tracking — updated from x-ratelimit-* response headers + # after each API call. Accessed by /usage slash command. + self._rate_limit_state: Optional["RateLimitState"] = None + # Centralized logging — agent.log (INFO+) and errors.log (WARNING+) # both live under ~/.hermes/logs/. Idempotent, so gateway mode # (which creates a new AIAgent per message) won't duplicate handlers. @@ -1687,9 +1700,25 @@ class AIAgent: return None def _cleanup_task_resources(self, task_id: str) -> None: - """Clean up VM and browser resources for a given task.""" + """Clean up VM and browser resources for a given task. + + Skips ``cleanup_vm`` when the active terminal environment is marked + persistent (``persistent_filesystem=True``) so that long-lived sandbox + containers survive between turns. The idle reaper in + ``terminal_tool._cleanup_inactive_envs`` still tears them down once + ``terminal.lifetime_seconds`` is exceeded. Non-persistent backends are + torn down per-turn as before to prevent resource leakage (the original + intent of this hook for the Morph backend, see commit fbd3a2fd). + """ try: - cleanup_vm(task_id) + if is_persistent_env(task_id): + if self.verbose_logging: + logging.debug( + f"Skipping per-turn cleanup_vm for persistent env {task_id}; " + f"idle reaper will handle it." + ) + else: + cleanup_vm(task_id) except Exception as e: if self.verbose_logging: logging.warning(f"Failed to cleanup VM for task {task_id}: {e}") @@ -2521,6 +2550,29 @@ class AIAgent: self._last_activity_ts = time.time() self._last_activity_desc = desc + def _capture_rate_limits(self, http_response: Any) -> None: + """Parse x-ratelimit-* headers from an HTTP response and cache the state. + + Called after each streaming API call. The httpx Response object is + available on the OpenAI SDK Stream via ``stream.response``. + """ + if http_response is None: + return + headers = getattr(http_response, "headers", None) + if not headers: + return + try: + from agent.rate_limit_tracker import parse_rate_limit_headers + state = parse_rate_limit_headers(headers, provider=self.provider) + if state is not None: + self._rate_limit_state = state + except Exception: + pass # Never let header parsing break the agent loop + + def get_rate_limit_state(self): + """Return the last captured RateLimitState, or None.""" + return self._rate_limit_state + def get_activity_summary(self) -> dict: """Return a snapshot of the agent's current activity for diagnostics. @@ -4375,6 +4427,11 @@ class AIAgent: self._touch_activity("waiting for provider response (streaming)") stream = request_client_holder["client"].chat.completions.create(**stream_kwargs) + # Capture rate limit headers from the initial HTTP response. + # The OpenAI SDK Stream object exposes the underlying httpx + # response via .response before any chunks are consumed. + self._capture_rate_limits(getattr(stream, "response", None)) + content_parts: list = [] tool_calls_acc: dict = {} tool_gen_notified: set = set() @@ -4728,18 +4785,25 @@ class AIAgent: self._close_request_openai_client(request_client, reason="stream_request_complete") _stream_stale_timeout_base = float(os.getenv("HERMES_STREAM_STALE_TIMEOUT", 180.0)) - # Scale the stale timeout for large contexts: slow models (like Opus) - # can legitimately think for minutes before producing the first token - # when the context is large. Without this, the stale detector kills - # healthy connections during the model's thinking phase, producing - # spurious RemoteProtocolError ("peer closed connection"). - _est_tokens = sum(len(str(v)) for v in api_kwargs.get("messages", [])) // 4 - if _est_tokens > 100_000: - _stream_stale_timeout = max(_stream_stale_timeout_base, 300.0) - elif _est_tokens > 50_000: - _stream_stale_timeout = max(_stream_stale_timeout_base, 240.0) + # Local providers (Ollama, oMLX, llama-cpp) can take 300+ seconds + # for prefill on large contexts. Disable the stale detector unless + # the user explicitly set HERMES_STREAM_STALE_TIMEOUT. + if _stream_stale_timeout_base == 180.0 and self.base_url and is_local_endpoint(self.base_url): + _stream_stale_timeout = float("inf") + logger.debug("Local provider detected (%s) — stale stream timeout disabled", self.base_url) else: - _stream_stale_timeout = _stream_stale_timeout_base + # Scale the stale timeout for large contexts: slow models (like Opus) + # can legitimately think for minutes before producing the first token + # when the context is large. Without this, the stale detector kills + # healthy connections during the model's thinking phase, producing + # spurious RemoteProtocolError ("peer closed connection"). + _est_tokens = sum(len(str(v)) for v in api_kwargs.get("messages", [])) // 4 + if _est_tokens > 100_000: + _stream_stale_timeout = max(_stream_stale_timeout_base, 300.0) + elif _est_tokens > 50_000: + _stream_stale_timeout = max(_stream_stale_timeout_base, 240.0) + else: + _stream_stale_timeout = _stream_stale_timeout_base t = threading.Thread(target=_call, daemon=True) t.start() @@ -5864,7 +5928,7 @@ class AIAgent: tools=[memory_tool_def], temperature=0.3, max_tokens=5120, - timeout=30.0, + # timeout resolved from auxiliary.flush_memories.timeout config ) except RuntimeError: _aux_available = False @@ -5896,7 +5960,10 @@ class AIAgent: "temperature": 0.3, **self._max_tokens_param(5120), } - response = self._ensure_primary_openai_client(reason="flush_memories").chat.completions.create(**api_kwargs, timeout=30.0) + from agent.auxiliary_client import _get_task_timeout + response = self._ensure_primary_openai_client(reason="flush_memories").chat.completions.create( + **api_kwargs, timeout=_get_task_timeout("flush_memories") + ) # Extract tool calls from the response, handling all API formats tool_calls = [] @@ -6003,6 +6070,15 @@ class AIAgent: except Exception as e: logger.warning("Session DB compression split failed — new session will NOT be indexed: %s", e) + # Warn on repeated compressions (quality degrades with each pass) + _cc = self.context_compressor.compression_count + if _cc >= 2: + self._vprint( + f"{self.log_prefix}⚠️ Session compressed {_cc} times — " + f"accuracy may degrade. Consider /new to start fresh.", + force=True, + ) + # Update token estimate after compaction so pressure calculations # use the post-compression count, not the stale pre-compression one. _compressed_est = ( @@ -6015,12 +6091,16 @@ class AIAgent: # Only reset the pressure warning if compression actually brought # us below the warning level (85% of threshold). When compression # can't reduce enough (e.g. threshold is very low, or system prompt - # alone exceeds the warning level), keep the flag set to prevent + # alone exceeds the warning level), keep the tier set to prevent # spamming the user with repeated warnings every loop iteration. if self.context_compressor.threshold_tokens > 0: _post_progress = _compressed_est / self.context_compressor.threshold_tokens if _post_progress < 0.85: - self._context_pressure_warned = False + self._context_pressure_warned_at = 0.0 + # Clear class-level dedup for this session so a fresh + # warning cycle can start if context grows again. + _sid = self.session_id or "default" + AIAgent._context_pressure_last_warned.pop(_sid, None) # Clear the file-read dedup cache. After compression the original # read content is summarised away — if the model re-reads the same @@ -7202,6 +7282,7 @@ class AIAgent: length_continue_retries = 0 truncated_response_prefix = "" compression_attempts = 0 + _turn_exit_reason = "unknown" # Diagnostic: why the loop ended # Clear any stale interrupt state at start self.clear_interrupt() @@ -7226,6 +7307,7 @@ class AIAgent: # Check for interrupt request (e.g., user sent new message) if self._interrupt_requested: interrupted = True + _turn_exit_reason = "interrupted_by_user" if not self.quiet_mode: self._safe_print("\n⚡ Breaking out of tool loop due to interrupt...") break @@ -7234,6 +7316,7 @@ class AIAgent: self._api_call_count = api_call_count self._touch_activity(f"starting API call #{api_call_count}") if not self.iteration_budget.consume(): + _turn_exit_reason = "budget_exhausted" if not self.quiet_mode: self._safe_print(f"\n⚠️ Iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} iterations used)") break @@ -7938,6 +8021,25 @@ class AIAgent: status_code = getattr(api_error, "status_code", None) error_context = self._extract_api_error_context(api_error) + + # ── Classify the error for structured recovery decisions ── + _compressor = getattr(self, "context_compressor", None) + _ctx_len = getattr(_compressor, "context_length", 200000) if _compressor else 200000 + classified = classify_api_error( + api_error, + provider=getattr(self, "provider", "") or "", + model=getattr(self, "model", "") or "", + approx_tokens=approx_tokens, + context_length=_ctx_len, + num_messages=len(api_messages) if api_messages else 0, + ) + logger.debug( + "Error classified: reason=%s status=%s retryable=%s compress=%s rotate=%s fallback=%s", + classified.reason.value, classified.status_code, + classified.retryable, classified.should_compress, + classified.should_rotate_credential, classified.should_fallback, + ) + recovered_with_pool, has_retried_429 = self._recover_with_credential_pool( status_code=status_code, has_retried_429=has_retried_429, @@ -8000,27 +8102,24 @@ class AIAgent: # from all messages so the next retry sends no thinking # blocks at all. One-shot — don't retry infinitely. if ( - self.api_mode == "anthropic_messages" - and status_code == 400 + classified.reason == FailoverReason.thinking_signature and not thinking_sig_retry_attempted ): - _err_msg_lower = str(api_error).lower() - if "signature" in _err_msg_lower and "thinking" in _err_msg_lower: - thinking_sig_retry_attempted = True - for _m in messages: - if isinstance(_m, dict): - _m.pop("reasoning_details", None) - self._vprint( - f"{self.log_prefix}⚠️ Thinking block signature invalid — " - f"stripped all thinking blocks, retrying...", - force=True, - ) - logging.warning( - "%sThinking block signature recovery: stripped " - "reasoning_details from %d messages", - self.log_prefix, len(messages), - ) - continue + thinking_sig_retry_attempted = True + for _m in messages: + if isinstance(_m, dict): + _m.pop("reasoning_details", None) + self._vprint( + f"{self.log_prefix}⚠️ Thinking block signature invalid — " + f"stripped all thinking blocks, retrying...", + force=True, + ) + logging.warning( + "%sThinking block signature recovery: stripped " + "reasoning_details from %d messages", + self.log_prefix, len(messages), + ) + continue retry_count += 1 elapsed_time = time.time() - api_start_time @@ -8077,14 +8176,7 @@ class AIAgent: # is NOT a transient rate limit — retrying or switching # credentials won't help. Reduce context to 200k (the # standard tier) and compress. - # Only applies to Sonnet — Opus 1M is general access. - _is_long_context_tier_error = ( - status_code == 429 - and "extra usage" in error_msg - and "long context" in error_msg - and "sonnet" in self.model.lower() - ) - if _is_long_context_tier_error: + if classified.reason == FailoverReason.long_context_tier: _reduced_ctx = 200000 compressor = self.context_compressor old_ctx = compressor.context_length @@ -8129,13 +8221,9 @@ class AIAgent: # When a fallback model is configured, switch immediately instead # of burning through retries with exponential backoff -- the # primary provider won't recover within the retry window. - is_rate_limited = ( - status_code == 429 - or "rate limit" in error_msg - or "too many requests" in error_msg - or "rate_limit" in error_msg - or "usage limit" in error_msg - or "quota" in error_msg + is_rate_limited = classified.reason in ( + FailoverReason.rate_limit, + FailoverReason.billing, ) if is_rate_limited and self._fallback_index < len(self._fallback_chain): # Don't eagerly fallback if credential pool rotation may @@ -8151,10 +8239,7 @@ class AIAgent: continue is_payload_too_large = ( - status_code == 413 - or 'request entity too large' in error_msg - or 'payload too large' in error_msg - or 'error code: 413' in error_msg + classified.reason == FailoverReason.payload_too_large ) if is_payload_too_large: @@ -8198,64 +8283,12 @@ class AIAgent: } # Check for context-length errors BEFORE generic 4xx handler. - # Local backends (LM Studio, Ollama, llama.cpp) often return - # HTTP 400 with messages like "Context size has been exceeded" - # which must trigger compression, not an immediate abort. - is_context_length_error = any(phrase in error_msg for phrase in [ - 'context length', 'context size', 'maximum context', - 'token limit', 'too many tokens', 'reduce the length', - 'exceeds the limit', 'context window', - 'request entity too large', # OpenRouter/Nous 413 safety net - 'prompt is too long', # Anthropic: "prompt is too long: N tokens > M maximum" - 'prompt exceeds max length', # Z.AI / GLM: generic 400 overflow wording - ]) - - # Fallback heuristic: Anthropic sometimes returns a generic - # 400 invalid_request_error with just "Error" as the message - # when the context is too large. If the error message is very - # short/generic AND the session is large, treat it as a - # probable context-length error and attempt compression rather - # than aborting. This prevents an infinite failure loop where - # each failed message gets persisted, making the session even - # larger. (#1630) - if not is_context_length_error and status_code == 400: - ctx_len = getattr(getattr(self, 'context_compressor', None), 'context_length', 200000) - is_large_session = approx_tokens > ctx_len * 0.4 or len(api_messages) > 80 - is_generic_error = len(error_msg.strip()) < 30 # e.g. just "error" - if is_large_session and is_generic_error: - is_context_length_error = True - self._vprint( - f"{self.log_prefix}⚠️ Generic 400 with large session " - f"(~{approx_tokens:,} tokens, {len(api_messages)} msgs) — " - f"treating as probable context overflow.", - force=True, - ) - - # Server disconnects on large sessions are often caused by - # the request exceeding the provider's context/payload limit - # without a proper HTTP error response. Treat these as - # context-length errors to trigger compression rather than - # burning through retries that will all fail the same way. - # This breaks the death spiral: disconnect → no token data - # → no compression → bigger session → more disconnects. - # (#2153) - if not is_context_length_error and not status_code: - _is_server_disconnect = ( - 'server disconnected' in error_msg - or 'peer closed connection' in error_msg - or error_type in ('ReadError', 'RemoteProtocolError', 'ServerDisconnectedError') - ) - if _is_server_disconnect: - ctx_len = getattr(getattr(self, 'context_compressor', None), 'context_length', 200000) - _is_large = approx_tokens > ctx_len * 0.6 or len(api_messages) > 200 - if _is_large: - is_context_length_error = True - self._vprint( - f"{self.log_prefix}⚠️ Server disconnected with large session " - f"(~{approx_tokens:,} tokens, {len(api_messages)} msgs) — " - f"treating as context-length error, attempting compression.", - force=True, - ) + # The classifier detects context overflow from: explicit error + # messages, generic 400 + large session heuristic (#1630), and + # server disconnect + large session pattern (#2153). + is_context_length_error = ( + classified.reason == FailoverReason.context_overflow + ) if is_context_length_error: compressor = self.context_compressor @@ -8327,35 +8360,30 @@ class AIAgent: "partial": True } - # Check for non-retryable client errors (4xx HTTP status codes). - # These indicate a problem with the request itself (bad model ID, - # invalid API key, forbidden, etc.) and will never succeed on retry. - # Note: 413 and context-length errors are excluded — handled above. - # 429 (rate limit) is transient and MUST be retried with backoff. - # 529 (Anthropic overloaded) is also transient. - # Also catch local validation errors (ValueError, TypeError) — these - # are programming bugs, not transient failures. - # Exclude UnicodeEncodeError — it's a ValueError subclass but is - # handled separately by the surrogate sanitization path above. - _RETRYABLE_STATUS_CODES = {413, 429, 529} + # Check for non-retryable client errors. The classifier + # already accounts for 413, 429, 529 (transient), context + # overflow, and generic-400 heuristics. Local validation + # errors (ValueError, TypeError) are programming bugs. is_local_validation_error = ( isinstance(api_error, (ValueError, TypeError)) and not isinstance(api_error, UnicodeEncodeError) ) - # Detect generic 400s from Anthropic OAuth (transient server-side failures). - # Real invalid_request_error responses include a descriptive message; - # transient ones contain only "Error" or are empty. (ref: issue #1608) - _err_body = getattr(api_error, "body", None) or {} - _err_message = (_err_body.get("error", {}).get("message", "") if isinstance(_err_body, dict) else "") - _is_generic_400 = (status_code == 400 and _err_message.strip().lower() in ("error", "")) - is_client_status_error = isinstance(status_code, int) and 400 <= status_code < 500 and status_code not in _RETRYABLE_STATUS_CODES and not _is_generic_400 - is_client_error = (is_local_validation_error or is_client_status_error or any(phrase in error_msg for phrase in [ - 'error code: 401', 'error code: 403', - 'error code: 404', 'error code: 422', - 'is not a valid model', 'invalid model', 'model not found', - 'invalid api key', 'invalid_api_key', 'authentication', - 'unauthorized', 'forbidden', 'not found', - ])) and not is_context_length_error + is_client_error = ( + is_local_validation_error + or ( + not classified.retryable + and not classified.should_compress + and classified.reason not in ( + FailoverReason.rate_limit, + FailoverReason.billing, + FailoverReason.overloaded, + FailoverReason.context_overflow, + FailoverReason.payload_too_large, + FailoverReason.long_context_tier, + FailoverReason.thinking_signature, + ) + ) + ) and not is_context_length_error if is_client_error: # Try fallback before aborting — a different provider @@ -8375,7 +8403,7 @@ class AIAgent: self._vprint(f"{self.log_prefix} 🔌 Provider: {_provider} Model: {_model}", force=True) self._vprint(f"{self.log_prefix} 🌐 Endpoint: {_base}", force=True) # Actionable guidance for common auth errors - if status_code in (401, 403) or "unauthorized" in error_msg or "forbidden" in error_msg or "permission" in error_msg: + if classified.is_auth or classified.reason == FailoverReason.billing: if _provider == "openai-codex" and status_code == 401: self._vprint(f"{self.log_prefix} 💡 Codex OAuth token was rejected (HTTP 401). Your token may have been", force=True) self._vprint(f"{self.log_prefix} refreshed by another client (Codex CLI, VS Code). To fix:", force=True) @@ -8535,6 +8563,7 @@ class AIAgent: # If the API call was interrupted, skip response processing if interrupted: + _turn_exit_reason = "interrupted_during_api_call" break if restart_with_compressed_messages: @@ -8554,6 +8583,7 @@ class AIAgent: # (e.g. repeated context-length errors that exhausted retry_count), # the `response` variable is still None. Break out cleanly. if response is None: + _turn_exit_reason = "all_retries_exhausted_no_response" print(f"{self.log_prefix}❌ All API retries exhausted with no successful response.") self._persist_session(messages, conversation_history) break @@ -8960,13 +8990,34 @@ class AIAgent: # compaction fires, not the raw context window. # Does not inject into messages — just prints to CLI output # and fires status_callback for gateway platforms. + # Tiered: 85% (orange) and 95% (red/critical). if _compressor.threshold_tokens > 0: _compaction_progress = _real_tokens / _compressor.threshold_tokens - if _compaction_progress >= 0.85 and not self._context_pressure_warned: - self._context_pressure_warned = True - self._emit_context_pressure(_compaction_progress, _compressor) + # Determine the warning tier for this progress level + _warn_tier = 0.0 + if _compaction_progress >= 0.95: + _warn_tier = 0.95 + elif _compaction_progress >= 0.85: + _warn_tier = 0.85 + if _warn_tier > self._context_pressure_warned_at: + # Class-level dedup: check if this session was already + # warned at this tier within the cooldown window. + _sid = self.session_id or "default" + _last = AIAgent._context_pressure_last_warned.get(_sid) + _now = time.time() + if _last is None or _last[0] < _warn_tier or (_now - _last[1]) >= self._CONTEXT_PRESSURE_COOLDOWN: + self._context_pressure_warned_at = _warn_tier + AIAgent._context_pressure_last_warned[_sid] = (_warn_tier, _now) + self._emit_context_pressure(_compaction_progress, _compressor) + # Evict stale entries (older than 2x cooldown) + _cutoff = _now - self._CONTEXT_PRESSURE_COOLDOWN * 2 + AIAgent._context_pressure_last_warned = { + k: v for k, v in AIAgent._context_pressure_last_warned.items() + if v[1] > _cutoff + } if self.compression_enabled and _compressor.should_compress(_real_tokens): + self._safe_print(" ⟳ compacting context…") messages, active_system_prompt = self._compress_context( messages, system_message, approx_tokens=self.context_compressor.last_prompt_tokens, @@ -8996,6 +9047,7 @@ class AIAgent: # instead of wasting API calls on retries that won't help. fallback = getattr(self, '_last_content_with_tools', None) if fallback: + _turn_exit_reason = "fallback_prior_turn_content" logger.debug("Empty follow-up after tool calls — using prior turn content as final response") self._last_content_with_tools = None self._empty_content_retries = 0 @@ -9041,8 +9093,28 @@ class AIAgent: self._save_session_log(messages) continue - # Exhausted prefill attempts or no structured - # reasoning — fall through to "(empty)" terminal. + # ── Empty response retry (no reasoning) ────── + # Model returned nothing — no content, no + # structured reasoning, no tool calls. Common + # with open models (transient provider issues, + # rate limits, sampling flukes). Silently retry + # up to 3 times before giving up. Skip when + # content has inline tags (model chose + # to reason, just no visible text). + _truly_empty = not final_response.strip() + if _truly_empty and not _has_structured and self._empty_content_retries < 3: + self._empty_content_retries += 1 + self._vprint( + f"{self.log_prefix}↻ Empty response (no content or reasoning) " + f"— retrying ({self._empty_content_retries}/3)", + force=True, + ) + continue + + # Exhausted prefill attempts, empty retries, or + # structured reasoning with no content — + # fall through to "(empty)" terminal. + _turn_exit_reason = "empty_response_exhausted" reasoning_text = self._extract_reasoning(assistant_message) assistant_msg = self._build_assistant_message(assistant_message, finish_reason) assistant_msg["content"] = "(empty)" @@ -9052,7 +9124,7 @@ class AIAgent: reasoning_preview = reasoning_text[:500] + "..." if len(reasoning_text) > 500 else reasoning_text self._vprint(f"{self.log_prefix}ℹ️ Reasoning-only response (no visible content). Reasoning: {reasoning_preview}") else: - self._vprint(f"{self.log_prefix}ℹ️ Empty response (no content or reasoning).") + self._vprint(f"{self.log_prefix}ℹ️ Empty response (no content or reasoning) after 3 retries.") final_response = "(empty)" break @@ -9114,6 +9186,7 @@ class AIAgent: messages.append(final_msg) + _turn_exit_reason = f"text_response(finish_reason={finish_reason})" if not self.quiet_mode: self._safe_print(f"🎉 Conversation completed after {api_call_count} OpenAI-compatible API call(s)") break @@ -9163,6 +9236,7 @@ class AIAgent: # If we're near the limit, break to avoid infinite loops if api_call_count >= self.max_iterations - 1: + _turn_exit_reason = f"error_near_max_iterations({error_msg[:80]})" final_response = f"I apologize, but I encountered repeated errors: {error_msg}" # Append as assistant so the history stays valid for # session resume (avoids consecutive user messages). @@ -9173,6 +9247,7 @@ class AIAgent: api_call_count >= self.max_iterations or self.iteration_budget.remaining <= 0 ): + _turn_exit_reason = f"max_iterations_reached({api_call_count}/{self.max_iterations})" if self.iteration_budget.remaining <= 0 and not self.quiet_mode: print(f"\n⚠️ Iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} iterations used)") final_response = self._handle_max_iterations(messages, api_call_count) @@ -9189,6 +9264,49 @@ class AIAgent: # Persist session to both JSON log and SQLite self._persist_session(messages, conversation_history) + # ── Turn-exit diagnostic log ───────────────────────────────────── + # Always logged at INFO so agent.log captures WHY every turn ended. + # When the last message is a tool result (agent was mid-work), log + # at WARNING — this is the "just stops" scenario users report. + _last_msg_role = messages[-1].get("role") if messages else None + _last_tool_name = None + if _last_msg_role == "tool": + # Walk back to find the assistant message with the tool call + for _m in reversed(messages): + if _m.get("role") == "assistant" and _m.get("tool_calls"): + _tcs = _m["tool_calls"] + if _tcs and isinstance(_tcs[0], dict): + _last_tool_name = _tcs[-1].get("function", {}).get("name") + break + + _turn_tool_count = sum( + 1 for m in messages + if isinstance(m, dict) and m.get("role") == "assistant" and m.get("tool_calls") + ) + _resp_len = len(final_response) if final_response else 0 + _budget_used = self.iteration_budget.used if self.iteration_budget else 0 + _budget_max = self.iteration_budget.max_total if self.iteration_budget else 0 + + _diag_msg = ( + "Turn ended: reason=%s model=%s api_calls=%d/%d budget=%d/%d " + "tool_turns=%d last_msg_role=%s response_len=%d session=%s" + ) + _diag_args = ( + _turn_exit_reason, self.model, api_call_count, self.max_iterations, + _budget_used, _budget_max, + _turn_tool_count, _last_msg_role, _resp_len, + self.session_id or "none", + ) + + if _last_msg_role == "tool" and not interrupted: + # Agent was mid-work — this is the "just stops" case. + logger.warning( + "Turn ended with pending tool result (agent may appear stuck). " + + _diag_msg + " last_tool=%s", + *_diag_args, _last_tool_name, + ) + else: + logger.info(_diag_msg, *_diag_args) # Plugin hook: post_llm_call # Fired once per turn after the tool-calling loop completes. diff --git a/tests/agent/test_auxiliary_client.py b/tests/agent/test_auxiliary_client.py index dd02ad23ab..3723378998 100644 --- a/tests/agent/test_auxiliary_client.py +++ b/tests/agent/test_auxiliary_client.py @@ -77,6 +77,20 @@ class TestReadCodexAccessToken: result = _read_codex_access_token() assert result == "tok-123" + def test_pool_without_selected_entry_falls_back_to_auth_store(self, tmp_path, monkeypatch): + hermes_home = tmp_path / "hermes" + hermes_home.mkdir(parents=True, exist_ok=True) + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + + valid_jwt = "eyJhbGciOiJSUzI1NiJ9.eyJleHAiOjk5OTk5OTk5OTl9.sig" + with patch("agent.auxiliary_client._select_pool_entry", return_value=(True, None)), \ + patch("hermes_cli.auth._read_codex_tokens", return_value={ + "tokens": {"access_token": valid_jwt, "refresh_token": "refresh"} + }): + result = _read_codex_access_token() + + assert result == valid_jwt + def test_missing_returns_none(self, tmp_path, monkeypatch): hermes_home = tmp_path / "hermes" hermes_home.mkdir(parents=True, exist_ok=True) @@ -238,6 +252,24 @@ class TestAnthropicOAuthFlag: assert mock_build.call_args.args[0] == "sk-ant-oat01-pooled" +class TestTryCodex: + def test_pool_without_selected_entry_falls_back_to_auth_store(self): + with ( + patch("agent.auxiliary_client._select_pool_entry", return_value=(True, None)), + patch("agent.auxiliary_client._read_codex_access_token", return_value="codex-auth-token"), + patch("agent.auxiliary_client.OpenAI") as mock_openai, + ): + mock_openai.return_value = MagicMock() + from agent.auxiliary_client import _try_codex + + client, model = _try_codex() + + assert client is not None + assert model == "gpt-5.2-codex" + assert mock_openai.call_args.kwargs["api_key"] == "codex-auth-token" + assert mock_openai.call_args.kwargs["base_url"] == "https://chatgpt.com/backend-api/codex" + + class TestExpiredCodexFallback: """Test that expired Codex tokens don't block the auto chain.""" diff --git a/tests/agent/test_context_compressor.py b/tests/agent/test_context_compressor.py index 257cf90395..42f6de0fd3 100644 --- a/tests/agent/test_context_compressor.py +++ b/tests/agent/test_context_compressor.py @@ -324,7 +324,10 @@ class TestCompressWithClient: with patch("agent.context_compressor.get_model_context_length", return_value=100000): c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2) - # Last head message (index 1) is "assistant" → summary should be "user" + # Last head message (index 1) is "assistant" → summary should be "user". + # With min_tail=3, tail = last 3 messages (indices 5-7). + # head_last=assistant, tail_first=assistant → summary_role="user", no collision. + # Need 8 messages: min_for_compress = 2+3+1 = 6, must have > 6. msgs = [ {"role": "user", "content": "msg 0"}, {"role": "assistant", "content": "msg 1"}, @@ -332,6 +335,8 @@ class TestCompressWithClient: {"role": "assistant", "content": "msg 3"}, {"role": "user", "content": "msg 4"}, {"role": "assistant", "content": "msg 5"}, + {"role": "user", "content": "msg 6"}, + {"role": "assistant", "content": "msg 7"}, ] with patch("agent.context_compressor.call_llm", return_value=mock_response): result = c.compress(msgs) @@ -460,8 +465,10 @@ class TestCompressWithClient: c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2) # Head: [system, user] → last head = user - # Tail: [assistant, user] → first tail = assistant + # Tail: [assistant, user, assistant] → first tail = assistant # summary_role="assistant" collides with tail, "user" collides with head → merge + # With min_tail=3, tail = last 3 messages (indices 5-7). + # Need 8 messages: min_for_compress = 2+3+1 = 6, must have > 6. msgs = [ {"role": "system", "content": "system prompt"}, {"role": "user", "content": "msg 1"}, @@ -470,6 +477,7 @@ class TestCompressWithClient: {"role": "assistant", "content": "msg 4"}, # compressed {"role": "assistant", "content": "msg 5"}, # tail start {"role": "user", "content": "msg 6"}, + {"role": "assistant", "content": "msg 7"}, ] with patch("agent.context_compressor.call_llm", return_value=mock_response): result = c.compress(msgs) @@ -481,7 +489,7 @@ class TestCompressWithClient: if r1 in ("user", "assistant") and r2 in ("user", "assistant"): assert r1 != r2, f"consecutive {r1} at indices {i-1},{i}" - # The summary should be merged into the first tail message (assistant) + # The summary should be merged into the first tail message (assistant at index 5) first_tail = [m for m in result if "msg 5" in (m.get("content") or "")] assert len(first_tail) == 1 assert "summary text" in first_tail[0]["content"] @@ -496,14 +504,18 @@ class TestCompressWithClient: with patch("agent.context_compressor.get_model_context_length", return_value=100000): c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2) - # Head=assistant, Tail=assistant → summary_role="user", no collision + # Head=assistant, Tail=assistant → summary_role="user", no collision. + # With min_tail=3, tail = last 3 messages (indices 5-7). + # Need 8 messages: min_for_compress = 2+3+1 = 6, must have > 6. msgs = [ {"role": "user", "content": "msg 0"}, {"role": "assistant", "content": "msg 1"}, {"role": "user", "content": "msg 2"}, {"role": "assistant", "content": "msg 3"}, - {"role": "assistant", "content": "msg 4"}, - {"role": "user", "content": "msg 5"}, + {"role": "user", "content": "msg 4"}, + {"role": "assistant", "content": "msg 5"}, + {"role": "user", "content": "msg 6"}, + {"role": "assistant", "content": "msg 7"}, ] with patch("agent.context_compressor.call_llm", return_value=mock_response): result = c.compress(msgs) @@ -600,3 +612,158 @@ class TestSummaryTargetRatio: with patch("agent.context_compressor.get_model_context_length", return_value=100_000): c = ContextCompressor(model="test", quiet_mode=True) assert c.protect_last_n == 20 + + +class TestTokenBudgetTailProtection: + """Tests for token-budget-based tail protection (PR #6240). + + The core change: tail protection is now based on a token budget rather + than a fixed message count. This prevents large tool outputs from + blocking compaction. + """ + + @pytest.fixture() + def budget_compressor(self): + """Compressor with known token budget for tail protection tests.""" + with patch("agent.context_compressor.get_model_context_length", return_value=200_000): + c = ContextCompressor( + model="test/model", + threshold_percent=0.50, # 100K threshold + protect_first_n=2, + protect_last_n=20, + quiet_mode=True, + ) + return c + + def test_large_tool_outputs_no_longer_block_compaction(self, budget_compressor): + """The motivating scenario: 20 messages with large tool outputs should + NOT prevent compaction. With message-count tail protection they would + all be protected, leaving nothing to summarize.""" + c = budget_compressor + messages = [ + {"role": "user", "content": "Start task"}, + {"role": "assistant", "content": "On it"}, + ] + # Add 20 messages with large tool outputs (~5K chars each ≈ 1250 tokens) + for i in range(10): + messages.append({ + "role": "assistant", "content": None, + "tool_calls": [{"function": {"name": f"tool_{i}", "arguments": "{}"}}], + }) + messages.append({ + "role": "tool", "content": "x" * 5000, + "tool_call_id": f"call_{i}", + }) + # Add 3 recent small messages + messages.append({"role": "user", "content": "What's the status?"}) + messages.append({"role": "assistant", "content": "Here's what I found..."}) + messages.append({"role": "user", "content": "Continue"}) + + # The tail cut should NOT protect all 20 tool messages + head_end = c.protect_first_n + cut = c._find_tail_cut_by_tokens(messages, head_end) + tail_size = len(messages) - cut + # With token budget, the tail should be much smaller than 20+ + assert tail_size < 20, f"Tail {tail_size} messages — large tool outputs are blocking compaction" + # But at least 3 (hard minimum) + assert tail_size >= 3 + + def test_min_tail_always_3_messages(self, budget_compressor): + """Even with a tiny token budget, at least 3 messages are protected.""" + c = budget_compressor + # Override to a tiny budget + c.tail_token_budget = 10 + messages = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + {"role": "user", "content": "do something"}, + {"role": "assistant", "content": "working on it"}, + {"role": "user", "content": "more work"}, + {"role": "assistant", "content": "done"}, + {"role": "user", "content": "thanks"}, + ] + head_end = 2 + cut = c._find_tail_cut_by_tokens(messages, head_end) + tail_size = len(messages) - cut + assert tail_size >= 3, f"Tail is only {tail_size} messages, min should be 3" + + def test_soft_ceiling_allows_oversized_message(self, budget_compressor): + """The 1.5x soft ceiling allows an oversized message to be included + rather than splitting it.""" + c = budget_compressor + # Set a small budget — 500 tokens + c.tail_token_budget = 500 + messages = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + {"role": "user", "content": "read the file"}, + # This message is ~600 tokens (> budget of 500, but < 1.5x = 750) + {"role": "assistant", "content": "a" * 2400}, + {"role": "user", "content": "short"}, + {"role": "assistant", "content": "short reply"}, + {"role": "user", "content": "continue"}, + ] + head_end = 2 + cut = c._find_tail_cut_by_tokens(messages, head_end) + # The oversized message at index 3 should NOT be the cut point + # because 1.5x ceiling = 750 tokens and accumulated would be ~610 + # (short msgs + oversized msg) which is < 750 + tail_size = len(messages) - cut + assert tail_size >= 3 + + def test_small_conversation_still_compresses(self, budget_compressor): + """With the new min of 8 messages (head=2 + 3 + 1 guard + 2 middle), + a small but compressible conversation should still compress.""" + c = budget_compressor + # 9 messages: head(2) + 4 middle + 3 tail = compressible + messages = [] + for i in range(9): + role = "user" if i % 2 == 0 else "assistant" + messages.append({"role": role, "content": f"Message {i}"}) + + # Should not early-return (needs > protect_first_n + 3 + 1 = 6) + # Mock the summary generation to avoid real API call + with patch.object(c, "_generate_summary", return_value="Summary of conversation"): + result = c.compress(messages, current_tokens=90_000) + # Should have compressed (fewer messages than original) + assert len(result) < len(messages) + + def test_prune_with_token_budget(self, budget_compressor): + """_prune_old_tool_results with protect_tail_tokens respects the budget.""" + c = budget_compressor + messages = [ + {"role": "user", "content": "start"}, + {"role": "assistant", "content": None, + "tool_calls": [{"function": {"name": "read_file", "arguments": '{"path": "big.txt"}'}}]}, + {"role": "tool", "content": "x" * 10000, "tool_call_id": "c1"}, # ~2500 tokens + {"role": "assistant", "content": None, + "tool_calls": [{"function": {"name": "read_file", "arguments": '{"path": "small.txt"}'}}]}, + {"role": "tool", "content": "y" * 10000, "tool_call_id": "c2"}, # ~2500 tokens + {"role": "user", "content": "short recent message"}, + {"role": "assistant", "content": "short reply"}, + ] + # With a 1000-token budget, only the last couple messages should be protected + result, pruned = c._prune_old_tool_results( + messages, protect_tail_count=2, protect_tail_tokens=1000, + ) + # At least one old tool result should have been pruned + assert pruned >= 1 + + def test_prune_without_token_budget_uses_message_count(self, budget_compressor): + """Without protect_tail_tokens, falls back to message-count behavior.""" + c = budget_compressor + messages = [ + {"role": "user", "content": "start"}, + {"role": "assistant", "content": None, + "tool_calls": [{"function": {"name": "tool", "arguments": "{}"}}]}, + {"role": "tool", "content": "x" * 5000, "tool_call_id": "c1"}, + {"role": "user", "content": "recent"}, + {"role": "assistant", "content": "reply"}, + ] + # protect_tail_count=3 means last 3 messages protected + result, pruned = c._prune_old_tool_results( + messages, protect_tail_count=3, + ) + # Tool at index 2 is outside the protected tail (last 3 = indices 2,3,4) + # so it might or might not be pruned depending on boundary + assert isinstance(pruned, int) diff --git a/tests/agent/test_credential_pool.py b/tests/agent/test_credential_pool.py index 891ab68a82..c3bde95156 100644 --- a/tests/agent/test_credential_pool.py +++ b/tests/agent/test_credential_pool.py @@ -214,6 +214,42 @@ def test_exhausted_entry_resets_after_ttl(tmp_path, monkeypatch): assert entry.last_status == "ok" +def test_exhausted_402_entry_resets_after_one_hour(tmp_path, monkeypatch): + """402-exhausted credentials recover after 1 hour, not 24.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes")) + _write_auth_store( + tmp_path, + { + "version": 1, + "credential_pool": { + "openrouter": [ + { + "id": "cred-1", + "label": "primary", + "auth_type": "api_key", + "priority": 0, + "source": "manual", + "access_token": "***", + "base_url": "https://openrouter.ai/api/v1", + "last_status": "exhausted", + "last_status_at": time.time() - 3700, # ~1h2m ago + "last_error_code": 402, + } + ] + }, + }, + ) + + from agent.credential_pool import load_pool + + pool = load_pool("openrouter") + entry = pool.select() + + assert entry is not None + assert entry.id == "cred-1" + assert entry.last_status == "ok" + + def test_explicit_reset_timestamp_overrides_default_429_ttl(tmp_path, monkeypatch): monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes")) _write_auth_store( diff --git a/tests/agent/test_error_classifier.py b/tests/agent/test_error_classifier.py new file mode 100644 index 0000000000..da248f8218 --- /dev/null +++ b/tests/agent/test_error_classifier.py @@ -0,0 +1,750 @@ +"""Tests for agent.error_classifier — structured API error classification.""" + +import pytest +from agent.error_classifier import ( + ClassifiedError, + FailoverReason, + classify_api_error, + _extract_status_code, + _extract_error_body, + _extract_error_code, + _classify_402, +) + + +# ── Helper: mock API errors ──────────────────────────────────────────── + +class MockAPIError(Exception): + """Simulates an OpenAI SDK APIStatusError.""" + def __init__(self, message, status_code=None, body=None): + super().__init__(message) + self.status_code = status_code + self.body = body or {} + + +class MockTransportError(Exception): + """Simulates a transport-level error with a specific type name.""" + pass + + +class ReadTimeout(MockTransportError): + pass + + +class ConnectError(MockTransportError): + pass + + +class RemoteProtocolError(MockTransportError): + pass + + +class ServerDisconnectedError(MockTransportError): + pass + + +# ── Test: FailoverReason enum ────────────────────────────────────────── + +class TestFailoverReason: + def test_all_reasons_have_string_values(self): + for reason in FailoverReason: + assert isinstance(reason.value, str) + + def test_enum_members_exist(self): + expected = { + "auth", "auth_permanent", "billing", "rate_limit", + "overloaded", "server_error", "timeout", + "context_overflow", "payload_too_large", + "model_not_found", "format_error", + "thinking_signature", "long_context_tier", "unknown", + } + actual = {r.value for r in FailoverReason} + assert expected == actual + + +# ── Test: ClassifiedError ────────────────────────────────────────────── + +class TestClassifiedError: + def test_is_auth_property(self): + e1 = ClassifiedError(reason=FailoverReason.auth) + assert e1.is_auth is True + + e2 = ClassifiedError(reason=FailoverReason.auth_permanent) + assert e2.is_auth is True + + e3 = ClassifiedError(reason=FailoverReason.billing) + assert e3.is_auth is False + + def test_is_transient_property(self): + transient_reasons = [ + FailoverReason.rate_limit, + FailoverReason.overloaded, + FailoverReason.server_error, + FailoverReason.timeout, + FailoverReason.unknown, + ] + for reason in transient_reasons: + e = ClassifiedError(reason=reason) + assert e.is_transient is True, f"{reason} should be transient" + + non_transient = [ + FailoverReason.auth, + FailoverReason.billing, + FailoverReason.model_not_found, + FailoverReason.format_error, + ] + for reason in non_transient: + e = ClassifiedError(reason=reason) + assert e.is_transient is False, f"{reason} should NOT be transient" + + def test_defaults(self): + e = ClassifiedError(reason=FailoverReason.unknown) + assert e.retryable is True + assert e.should_compress is False + assert e.should_rotate_credential is False + assert e.should_fallback is False + assert e.status_code is None + assert e.message == "" + + +# ── Test: Status code extraction ─────────────────────────────────────── + +class TestExtractStatusCode: + def test_from_status_code_attr(self): + e = MockAPIError("fail", status_code=429) + assert _extract_status_code(e) == 429 + + def test_from_status_attr(self): + class ErrWithStatus(Exception): + status = 503 + assert _extract_status_code(ErrWithStatus()) == 503 + + def test_from_cause_chain(self): + inner = MockAPIError("inner", status_code=401) + outer = Exception("outer") + outer.__cause__ = inner + assert _extract_status_code(outer) == 401 + + def test_none_when_missing(self): + assert _extract_status_code(Exception("generic")) is None + + def test_rejects_non_http_status(self): + """Integers outside 100-599 on .status should be ignored.""" + class ErrWeirdStatus(Exception): + status = 42 + assert _extract_status_code(ErrWeirdStatus()) is None + + +# ── Test: Error body extraction ──────────────────────────────────────── + +class TestExtractErrorBody: + def test_from_body_attr(self): + e = MockAPIError("fail", body={"error": {"message": "bad"}}) + assert _extract_error_body(e) == {"error": {"message": "bad"}} + + def test_empty_when_no_body(self): + assert _extract_error_body(Exception("generic")) == {} + + +# ── Test: Error code extraction ──────────────────────────────────────── + +class TestExtractErrorCode: + def test_from_nested_error_code(self): + body = {"error": {"code": "rate_limit_exceeded"}} + assert _extract_error_code(body) == "rate_limit_exceeded" + + def test_from_nested_error_type(self): + body = {"error": {"type": "invalid_request_error"}} + assert _extract_error_code(body) == "invalid_request_error" + + def test_from_top_level_code(self): + body = {"code": "model_not_found"} + assert _extract_error_code(body) == "model_not_found" + + def test_empty_when_no_code(self): + assert _extract_error_code({}) == "" + assert _extract_error_code({"error": {"message": "oops"}}) == "" + + +# ── Test: 402 disambiguation ─────────────────────────────────────────── + +class TestClassify402: + """The critical 402 billing vs rate_limit disambiguation.""" + + def test_billing_exhaustion(self): + """Plain 402 = billing.""" + result = _classify_402( + "payment required", + lambda reason, **kw: ClassifiedError(reason=reason, **kw), + ) + assert result.reason == FailoverReason.billing + assert result.should_rotate_credential is True + + def test_transient_usage_limit(self): + """402 with 'usage limit' + 'try again' = rate limit, not billing.""" + result = _classify_402( + "usage limit exceeded. try again in 5 minutes", + lambda reason, **kw: ClassifiedError(reason=reason, **kw), + ) + assert result.reason == FailoverReason.rate_limit + assert result.should_rotate_credential is True + + def test_quota_with_retry(self): + """402 with 'quota' + 'retry' = rate limit.""" + result = _classify_402( + "quota exceeded, please retry after the window resets", + lambda reason, **kw: ClassifiedError(reason=reason, **kw), + ) + assert result.reason == FailoverReason.rate_limit + + def test_quota_without_retry(self): + """402 with just 'quota' but no transient signal = billing.""" + result = _classify_402( + "quota exceeded", + lambda reason, **kw: ClassifiedError(reason=reason, **kw), + ) + assert result.reason == FailoverReason.billing + + def test_insufficient_credits(self): + result = _classify_402( + "insufficient credits to complete request", + lambda reason, **kw: ClassifiedError(reason=reason, **kw), + ) + assert result.reason == FailoverReason.billing + + +# ── Test: Full classification pipeline ───────────────────────────────── + +class TestClassifyApiError: + """End-to-end classification tests.""" + + # ── Auth errors ── + + def test_401_classified_as_auth(self): + e = MockAPIError("Unauthorized", status_code=401) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.auth + assert result.should_rotate_credential is True + # 401 is non-retryable on its own — credential rotation runs + # before the retryability check in the agent loop. + assert result.retryable is False + assert result.should_fallback is True + + def test_403_classified_as_auth(self): + e = MockAPIError("Forbidden", status_code=403) + result = classify_api_error(e, provider="anthropic") + assert result.reason == FailoverReason.auth + assert result.should_fallback is True + + def test_403_key_limit_classified_as_billing(self): + """OpenRouter 403 'key limit exceeded' is billing, not auth.""" + e = MockAPIError("Key limit exceeded for this key", status_code=403) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.billing + assert result.should_rotate_credential is True + assert result.should_fallback is True + + def test_403_spending_limit_classified_as_billing(self): + e = MockAPIError("spending limit reached", status_code=403) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.billing + + # ── Billing ── + + def test_402_plain_billing(self): + e = MockAPIError("Payment Required", status_code=402) + result = classify_api_error(e) + assert result.reason == FailoverReason.billing + assert result.retryable is False + + def test_402_transient_usage_limit(self): + e = MockAPIError("usage limit exceeded, try again later", status_code=402) + result = classify_api_error(e) + assert result.reason == FailoverReason.rate_limit + assert result.retryable is True + + # ── Rate limit ── + + def test_429_rate_limit(self): + e = MockAPIError("Too Many Requests", status_code=429) + result = classify_api_error(e) + assert result.reason == FailoverReason.rate_limit + assert result.should_fallback is True + + # ── Server errors ── + + def test_500_server_error(self): + e = MockAPIError("Internal Server Error", status_code=500) + result = classify_api_error(e) + assert result.reason == FailoverReason.server_error + assert result.retryable is True + + def test_502_server_error(self): + e = MockAPIError("Bad Gateway", status_code=502) + result = classify_api_error(e) + assert result.reason == FailoverReason.server_error + + def test_503_overloaded(self): + e = MockAPIError("Service Unavailable", status_code=503) + result = classify_api_error(e) + assert result.reason == FailoverReason.overloaded + + def test_529_anthropic_overloaded(self): + e = MockAPIError("Overloaded", status_code=529) + result = classify_api_error(e) + assert result.reason == FailoverReason.overloaded + + # ── Model not found ── + + def test_404_model_not_found(self): + e = MockAPIError("model not found", status_code=404) + result = classify_api_error(e) + assert result.reason == FailoverReason.model_not_found + assert result.should_fallback is True + assert result.retryable is False + + def test_404_generic(self): + e = MockAPIError("Not Found", status_code=404) + result = classify_api_error(e) + assert result.reason == FailoverReason.model_not_found + + # ── Payload too large ── + + def test_413_payload_too_large(self): + e = MockAPIError("Request Entity Too Large", status_code=413) + result = classify_api_error(e) + assert result.reason == FailoverReason.payload_too_large + assert result.should_compress is True + + # ── Context overflow ── + + def test_400_context_length(self): + e = MockAPIError("context length exceeded: 250000 > 200000", status_code=400) + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + assert result.should_compress is True + + def test_400_too_many_tokens(self): + e = MockAPIError("This model's maximum context is 128000 tokens, too many tokens", status_code=400) + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + + def test_400_prompt_too_long(self): + e = MockAPIError("prompt is too long: 300000 tokens > 200000 maximum", status_code=400) + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + + def test_400_generic_large_session(self): + """Generic 400 with large session → context overflow heuristic.""" + e = MockAPIError( + "Error", + status_code=400, + body={"error": {"message": "Error"}}, + ) + result = classify_api_error(e, approx_tokens=100000, context_length=200000) + assert result.reason == FailoverReason.context_overflow + + def test_400_generic_small_session_is_format_error(self): + """Generic 400 with small session → format error, not context overflow.""" + e = MockAPIError( + "Error", + status_code=400, + body={"error": {"message": "Error"}}, + ) + result = classify_api_error(e, approx_tokens=1000, context_length=200000) + assert result.reason == FailoverReason.format_error + + # ── Server disconnect + large session ── + + def test_disconnect_large_session_context_overflow(self): + """Server disconnect with large session → context overflow.""" + e = Exception("server disconnected without sending complete message") + result = classify_api_error(e, approx_tokens=150000, context_length=200000) + assert result.reason == FailoverReason.context_overflow + assert result.should_compress is True + + def test_disconnect_small_session_timeout(self): + """Server disconnect with small session → timeout.""" + e = Exception("server disconnected without sending complete message") + result = classify_api_error(e, approx_tokens=5000, context_length=200000) + assert result.reason == FailoverReason.timeout + + # ── Provider-specific: Anthropic thinking signature ── + + def test_anthropic_thinking_signature(self): + e = MockAPIError( + "thinking block has invalid signature", + status_code=400, + ) + result = classify_api_error(e, provider="anthropic") + assert result.reason == FailoverReason.thinking_signature + assert result.retryable is True + + def test_non_anthropic_400_with_signature_not_classified_as_thinking(self): + """400 with 'signature' but from non-Anthropic → format error.""" + e = MockAPIError("invalid signature", status_code=400) + result = classify_api_error(e, provider="openrouter", approx_tokens=0) + # Without "thinking" in the message, it shouldn't be thinking_signature + assert result.reason != FailoverReason.thinking_signature + + # ── Provider-specific: Anthropic long-context tier ── + + def test_anthropic_long_context_tier(self): + e = MockAPIError( + "Extra usage is required for long context requests over 200k tokens", + status_code=429, + ) + result = classify_api_error(e, provider="anthropic", model="claude-sonnet-4") + assert result.reason == FailoverReason.long_context_tier + assert result.should_compress is True + + def test_normal_429_not_long_context(self): + """Normal 429 without 'extra usage' + 'long context' → rate_limit.""" + e = MockAPIError("Too Many Requests", status_code=429) + result = classify_api_error(e, provider="anthropic") + assert result.reason == FailoverReason.rate_limit + + # ── Transport errors ── + + def test_read_timeout(self): + e = ReadTimeout("Read timed out") + result = classify_api_error(e) + assert result.reason == FailoverReason.timeout + assert result.retryable is True + + def test_connect_error(self): + e = ConnectError("Connection refused") + result = classify_api_error(e) + assert result.reason == FailoverReason.timeout + + def test_connection_error_builtin(self): + e = ConnectionError("Connection reset by peer") + result = classify_api_error(e) + assert result.reason == FailoverReason.timeout + + def test_timeout_error_builtin(self): + e = TimeoutError("timed out") + result = classify_api_error(e) + assert result.reason == FailoverReason.timeout + + # ── Error code classification ── + + def test_error_code_resource_exhausted(self): + e = MockAPIError( + "Resource exhausted", + body={"error": {"code": "resource_exhausted", "message": "Too many requests"}}, + ) + result = classify_api_error(e) + assert result.reason == FailoverReason.rate_limit + + def test_error_code_model_not_found(self): + e = MockAPIError( + "Model not available", + body={"error": {"code": "model_not_found"}}, + ) + result = classify_api_error(e) + assert result.reason == FailoverReason.model_not_found + + def test_error_code_context_length_exceeded(self): + e = MockAPIError( + "Context too large", + body={"error": {"code": "context_length_exceeded"}}, + ) + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + + # ── Message-only patterns (no status code) ── + + def test_message_billing_pattern(self): + e = Exception("insufficient credits to complete this request") + result = classify_api_error(e) + assert result.reason == FailoverReason.billing + + def test_message_rate_limit_pattern(self): + e = Exception("rate limit reached for this model") + result = classify_api_error(e) + assert result.reason == FailoverReason.rate_limit + + def test_message_auth_pattern(self): + e = Exception("invalid api key provided") + result = classify_api_error(e) + assert result.reason == FailoverReason.auth + + def test_message_model_not_found_pattern(self): + e = Exception("gpt-99 is not a valid model") + result = classify_api_error(e) + assert result.reason == FailoverReason.model_not_found + + def test_message_context_overflow_pattern(self): + e = Exception("maximum context length exceeded") + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + + # ── Unknown / fallback ── + + def test_generic_exception_is_unknown(self): + e = Exception("something weird happened") + result = classify_api_error(e) + assert result.reason == FailoverReason.unknown + assert result.retryable is True + + # ── Format error ── + + def test_400_descriptive_format_error(self): + """400 with descriptive message (not context overflow) → format error.""" + e = MockAPIError( + "Invalid value for parameter 'temperature': must be between 0 and 2", + status_code=400, + body={"error": {"message": "Invalid value for parameter 'temperature': must be between 0 and 2"}}, + ) + result = classify_api_error(e, approx_tokens=1000) + assert result.reason == FailoverReason.format_error + assert result.retryable is False + + def test_422_format_error(self): + e = MockAPIError("Unprocessable Entity", status_code=422) + result = classify_api_error(e) + assert result.reason == FailoverReason.format_error + assert result.retryable is False + + # ── Peer closed + large session ── + + def test_peer_closed_large_session(self): + e = Exception("peer closed connection without sending complete message") + result = classify_api_error(e, approx_tokens=130000, context_length=200000) + assert result.reason == FailoverReason.context_overflow + + # ── Chinese error messages ── + + def test_chinese_context_overflow(self): + e = MockAPIError("超过最大长度限制", status_code=400) + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + + # ── Result metadata ── + + def test_provider_and_model_in_result(self): + e = MockAPIError("fail", status_code=500) + result = classify_api_error(e, provider="openrouter", model="gpt-5") + assert result.provider == "openrouter" + assert result.model == "gpt-5" + assert result.status_code == 500 + + def test_message_extracted(self): + e = MockAPIError( + "outer", + status_code=500, + body={"error": {"message": "Internal server error occurred"}}, + ) + result = classify_api_error(e) + assert result.message == "Internal server error occurred" + + +# ── Test: Adversarial / edge cases (from live testing) ───────────────── + +class TestAdversarialEdgeCases: + """Edge cases discovered during live testing with real SDK objects.""" + + def test_empty_exception_message(self): + result = classify_api_error(Exception("")) + assert result.reason == FailoverReason.unknown + assert result.retryable is True + + def test_500_with_none_body(self): + e = MockAPIError("fail", status_code=500, body=None) + result = classify_api_error(e) + assert result.reason == FailoverReason.server_error + + def test_non_dict_body(self): + """Some providers return strings instead of JSON.""" + class StringBodyError(Exception): + status_code = 400 + body = "just a string" + result = classify_api_error(StringBodyError("bad")) + assert result.reason == FailoverReason.format_error + + def test_list_body(self): + class ListBodyError(Exception): + status_code = 500 + body = [{"error": "something"}] + result = classify_api_error(ListBodyError("server error")) + assert result.reason == FailoverReason.server_error + + def test_circular_cause_chain(self): + """Must not infinite-loop on circular __cause__.""" + e = Exception("circular") + e.__cause__ = e + result = classify_api_error(e) + assert result.reason == FailoverReason.unknown + + def test_three_level_cause_chain(self): + inner = MockAPIError("inner", status_code=429) + middle = Exception("middle") + middle.__cause__ = inner + outer = RuntimeError("outer") + outer.__cause__ = middle + result = classify_api_error(outer) + assert result.status_code == 429 + assert result.reason == FailoverReason.rate_limit + + def test_400_with_rate_limit_text(self): + """Some providers send rate limits as 400 instead of 429.""" + e = MockAPIError( + "rate limit policy", + status_code=400, + body={"error": {"message": "rate limit exceeded on this model"}}, + ) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.rate_limit + + def test_400_with_billing_text(self): + """Some providers send billing errors as 400.""" + e = MockAPIError( + "billing", + status_code=400, + body={"error": {"message": "insufficient credits for this request"}}, + ) + result = classify_api_error(e) + assert result.reason == FailoverReason.billing + + def test_200_with_error_body(self): + """200 status with error in body — should be unknown, not crash.""" + class WeirdSuccess(Exception): + status_code = 200 + body = {"error": {"message": "loading"}} + result = classify_api_error(WeirdSuccess("model loading")) + assert result.reason == FailoverReason.unknown + + def test_ollama_context_size_exceeded(self): + e = MockAPIError( + "Error", + status_code=400, + body={"error": {"message": "context size has been exceeded"}}, + ) + result = classify_api_error(e, provider="ollama") + assert result.reason == FailoverReason.context_overflow + + def test_connection_refused_error(self): + e = ConnectionRefusedError("Connection refused: localhost:11434") + result = classify_api_error(e, provider="ollama") + assert result.reason == FailoverReason.timeout + + def test_body_message_enrichment(self): + """Body message must be included in pattern matching even when + str(error) doesn't contain it (OpenAI SDK APIStatusError).""" + e = MockAPIError( + "Usage limit", # str(e) = "usage limit" + status_code=402, + body={"error": {"message": "Usage limit reached, try again in 5 minutes"}}, + ) + result = classify_api_error(e) + # "try again" is only in body, not in str(e) + assert result.reason == FailoverReason.rate_limit + + def test_disconnect_pattern_ordering(self): + """Disconnect + large session must beat generic transport catch.""" + class FakeRemoteProtocol(Exception): + pass + # Type name isn't in _TRANSPORT_ERROR_TYPES but message has disconnect pattern + e = Exception("peer closed connection without sending complete message") + result = classify_api_error(e, approx_tokens=150000, context_length=200000) + assert result.reason == FailoverReason.context_overflow + assert result.should_compress is True + + def test_credit_balance_too_low(self): + e = MockAPIError( + "Credits low", + status_code=402, + body={"error": {"message": "Your credit balance is too low"}}, + ) + result = classify_api_error(e, provider="anthropic") + assert result.reason == FailoverReason.billing + + def test_deepseek_402_chinese(self): + """Chinese billing message should still match billing patterns.""" + # "余额不足" doesn't match English billing patterns, but 402 defaults to billing + e = MockAPIError("余额不足", status_code=402) + result = classify_api_error(e, provider="deepseek") + assert result.reason == FailoverReason.billing + + def test_openrouter_wrapped_context_overflow_in_metadata_raw(self): + """OpenRouter wraps provider errors in metadata.raw JSON string.""" + e = MockAPIError( + "Provider returned error", + status_code=400, + body={ + "error": { + "message": "Provider returned error", + "code": 400, + "metadata": { + "raw": '{"error":{"message":"context length exceeded: 50000 > 32768"}}' + } + } + }, + ) + result = classify_api_error(e, provider="openrouter", approx_tokens=10000) + assert result.reason == FailoverReason.context_overflow + assert result.should_compress is True + + def test_openrouter_wrapped_rate_limit_in_metadata_raw(self): + e = MockAPIError( + "Provider returned error", + status_code=400, + body={ + "error": { + "message": "Provider returned error", + "metadata": { + "raw": '{"error":{"message":"Rate limit exceeded. Please retry after 30s."}}' + } + } + }, + ) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.rate_limit + + def test_thinking_signature_via_openrouter(self): + """Thinking signature errors proxied through OpenRouter must be caught.""" + e = MockAPIError( + "thinking block has invalid signature", + status_code=400, + ) + # provider is openrouter, not anthropic — old code missed this + result = classify_api_error(e, provider="openrouter", model="anthropic/claude-sonnet-4") + assert result.reason == FailoverReason.thinking_signature + + def test_generic_400_large_by_message_count(self): + """Many small messages (>80) should trigger context overflow heuristic.""" + e = MockAPIError( + "Error", + status_code=400, + body={"error": {"message": "Error"}}, + ) + # Low token count but high message count + result = classify_api_error( + e, approx_tokens=5000, context_length=200000, num_messages=100, + ) + assert result.reason == FailoverReason.context_overflow + + def test_disconnect_large_by_message_count(self): + """Server disconnect with 200+ messages should trigger context overflow.""" + e = Exception("server disconnected without sending complete message") + result = classify_api_error( + e, approx_tokens=5000, context_length=200000, num_messages=250, + ) + assert result.reason == FailoverReason.context_overflow + + def test_openrouter_wrapped_model_not_found_in_metadata_raw(self): + e = MockAPIError( + "Provider returned error", + status_code=400, + body={ + "error": { + "message": "Provider returned error", + "metadata": { + "raw": '{"error":{"message":"The model gpt-99 does not exist"}}' + } + } + }, + ) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.model_not_found diff --git a/tests/agent/test_rate_limit_tracker.py b/tests/agent/test_rate_limit_tracker.py new file mode 100644 index 0000000000..caef785678 --- /dev/null +++ b/tests/agent/test_rate_limit_tracker.py @@ -0,0 +1,212 @@ +"""Tests for agent.rate_limit_tracker — header parsing and formatting.""" + +import time +import pytest +from agent.rate_limit_tracker import ( + RateLimitBucket, + RateLimitState, + parse_rate_limit_headers, + format_rate_limit_display, + format_rate_limit_compact, + _fmt_count, + _fmt_seconds, + _bar, +) + + +# ── Sample headers from Nous inference API ────────────────────────────── + +NOUS_HEADERS = { + "x-ratelimit-limit-requests": "800", + "x-ratelimit-limit-requests-1h": "33600", + "x-ratelimit-limit-tokens": "8000000", + "x-ratelimit-limit-tokens-1h": "336000000", + "x-ratelimit-remaining-requests": "795", + "x-ratelimit-remaining-requests-1h": "33590", + "x-ratelimit-remaining-tokens": "7999500", + "x-ratelimit-remaining-tokens-1h": "335999000", + "x-ratelimit-reset-requests": "45.5", + "x-ratelimit-reset-requests-1h": "3500.0", + "x-ratelimit-reset-tokens": "42.3", + "x-ratelimit-reset-tokens-1h": "3490.0", +} + + +class TestParseHeaders: + def test_basic_parsing(self): + state = parse_rate_limit_headers(NOUS_HEADERS, provider="nous") + assert state is not None + assert state.provider == "nous" + assert state.has_data + + assert state.requests_min.limit == 800 + assert state.requests_min.remaining == 795 + assert state.requests_min.reset_seconds == 45.5 + + assert state.requests_hour.limit == 33600 + assert state.requests_hour.remaining == 33590 + + assert state.tokens_min.limit == 8000000 + assert state.tokens_min.remaining == 7999500 + + assert state.tokens_hour.limit == 336000000 + assert state.tokens_hour.remaining == 335999000 + assert state.tokens_hour.reset_seconds == 3490.0 + + def test_no_headers(self): + state = parse_rate_limit_headers({}) + assert state is None + + def test_partial_headers(self): + headers = { + "x-ratelimit-limit-requests": "100", + "x-ratelimit-remaining-requests": "50", + } + state = parse_rate_limit_headers(headers) + assert state is not None + assert state.requests_min.limit == 100 + assert state.requests_min.remaining == 50 + # Missing fields default to 0 + assert state.tokens_min.limit == 0 + + def test_non_rate_limit_headers_ignored(self): + headers = { + "content-type": "application/json", + "server": "nginx", + } + state = parse_rate_limit_headers(headers) + assert state is None + + def test_malformed_values(self): + headers = { + "x-ratelimit-limit-requests": "not-a-number", + "x-ratelimit-remaining-requests": "", + "x-ratelimit-reset-requests": "abc", + } + state = parse_rate_limit_headers(headers) + assert state is not None + assert state.requests_min.limit == 0 + assert state.requests_min.remaining == 0 + assert state.requests_min.reset_seconds == 0.0 + + +class TestBucket: + def test_used(self): + b = RateLimitBucket(limit=800, remaining=795, reset_seconds=45.0, captured_at=time.time()) + assert b.used == 5 + + def test_usage_pct(self): + b = RateLimitBucket(limit=100, remaining=20, reset_seconds=30.0, captured_at=time.time()) + assert b.usage_pct == pytest.approx(80.0) + + def test_usage_pct_zero_limit(self): + b = RateLimitBucket(limit=0, remaining=0) + assert b.usage_pct == 0.0 + + def test_remaining_seconds_now(self): + now = time.time() + b = RateLimitBucket(limit=800, remaining=795, reset_seconds=60.0, captured_at=now - 10) + # ~50 seconds should remain + assert 49 <= b.remaining_seconds_now <= 51 + + def test_remaining_seconds_expired(self): + b = RateLimitBucket(limit=800, remaining=795, reset_seconds=30.0, captured_at=time.time() - 60) + assert b.remaining_seconds_now == 0.0 + + +class TestFormatting: + def test_fmt_count_millions(self): + assert _fmt_count(8000000) == "8.0M" + assert _fmt_count(336000000) == "336.0M" + + def test_fmt_count_thousands(self): + assert _fmt_count(33600) == "33.6K" + assert _fmt_count(1500) == "1.5K" + + def test_fmt_count_small(self): + assert _fmt_count(800) == "800" + assert _fmt_count(0) == "0" + + def test_fmt_seconds_short(self): + assert _fmt_seconds(45) == "45s" + assert _fmt_seconds(0) == "0s" + + def test_fmt_seconds_minutes(self): + assert _fmt_seconds(125) == "2m 5s" + assert _fmt_seconds(120) == "2m" + + def test_fmt_seconds_hours(self): + assert _fmt_seconds(3660) == "1h 1m" + assert _fmt_seconds(3600) == "1h" + + def test_bar(self): + bar = _bar(50.0, width=10) + assert bar == "[█████░░░░░]" + assert _bar(0.0, width=10) == "[░░░░░░░░░░]" + assert _bar(100.0, width=10) == "[██████████]" + + def test_format_display_no_data(self): + state = RateLimitState() + result = format_rate_limit_display(state) + assert "No rate limit data" in result + + def test_format_display_with_data(self): + state = parse_rate_limit_headers(NOUS_HEADERS, provider="nous") + result = format_rate_limit_display(state) + assert "Nous" in result + assert "Requests/min" in result + assert "Requests/hr" in result + assert "Tokens/min" in result + assert "Tokens/hr" in result + assert "resets in" in result + + def test_format_display_warning_on_high_usage(self): + headers = { + **NOUS_HEADERS, + "x-ratelimit-remaining-requests": "50", # 750/800 used = 93.75% + } + state = parse_rate_limit_headers(headers) + result = format_rate_limit_display(state) + assert "⚠" in result + + def test_format_compact(self): + state = parse_rate_limit_headers(NOUS_HEADERS, provider="nous") + result = format_rate_limit_compact(state) + assert "RPM:" in result + assert "RPH:" in result + assert "TPM:" in result + assert "TPH:" in result + assert "resets" in result + + def test_format_compact_no_data(self): + state = RateLimitState() + result = format_rate_limit_compact(state) + assert "No rate limit data" in result + + +class TestAgentIntegration: + """Test that AIAgent captures rate limit state correctly.""" + + def test_capture_rate_limits_from_headers(self): + """Simulate the header capture path without a real API call.""" + import sys + import os + # Use a mock httpx-like response + class MockResponse: + headers = NOUS_HEADERS + + # Import AIAgent minimally + from unittest.mock import MagicMock, patch + + # Test the parsing directly + state = parse_rate_limit_headers(MockResponse.headers, provider="nous") + assert state is not None + assert state.requests_min.limit == 800 + assert state.tokens_hour.limit == 336000000 + + def test_capture_rate_limits_none_response(self): + """_capture_rate_limits should handle None gracefully.""" + from agent.rate_limit_tracker import parse_rate_limit_headers + # None should not crash + result = parse_rate_limit_headers({}) + assert result is None diff --git a/tests/agent/test_subdirectory_hints.py b/tests/agent/test_subdirectory_hints.py index 7d2bc607c8..7c1a74e66c 100644 --- a/tests/agent/test_subdirectory_hints.py +++ b/tests/agent/test_subdirectory_hints.py @@ -3,6 +3,7 @@ import os import pytest from pathlib import Path +from unittest.mock import patch from agent.subdirectory_hints import SubdirectoryHintTracker @@ -189,3 +190,45 @@ class TestSubdirectoryHintTracker: "terminal", {"command": "curl https://example.com/frontend/api"} ) assert result is None + + +class TestPermissionErrorHandling: + """Regression tests for PermissionError in filesystem checks (ref #6214).""" + + def test_is_valid_subdir_permission_error(self, tmp_path): + """_is_valid_subdir should return False when is_dir() raises PermissionError.""" + tracker = SubdirectoryHintTracker(working_dir=str(tmp_path)) + restricted = tmp_path / "restricted" + restricted.mkdir() + with patch.object(Path, "is_dir", side_effect=PermissionError("Permission denied")): + assert tracker._is_valid_subdir(restricted) is False + + def test_load_hints_permission_error_on_is_file(self, tmp_path): + """_load_hints_for_directory should skip files when is_file() raises PermissionError.""" + tracker = SubdirectoryHintTracker(working_dir=str(tmp_path)) + restricted = tmp_path / "restricted" + restricted.mkdir() + original_is_file = Path.is_file + def patched_is_file(self): + if "restricted" in str(self): + raise PermissionError("Permission denied") + return original_is_file(self) + with patch.object(Path, "is_file", patched_is_file): + result = tracker._load_hints_for_directory(restricted) + assert result is None + + def test_check_tool_call_survives_inaccessible_path(self, project): + """Full check_tool_call should not crash when a path is inaccessible.""" + tracker = SubdirectoryHintTracker(working_dir=str(project)) + original_is_dir = Path.is_dir + def patched_is_dir(self): + if "backend" in str(self) and "src" not in str(self): + raise PermissionError("Permission denied") + return original_is_dir(self) + with patch.object(Path, "is_dir", patched_is_dir): + # Should not raise — gracefully skip the inaccessible directory + result = tracker.check_tool_call( + "read_file", {"path": str(project / "backend" / "src" / "main.py")} + ) + # Result may be None (backend skipped) — the key point is no crash + assert result is None or isinstance(result, str) diff --git a/tests/cli/test_cli_approval_ui.py b/tests/cli/test_cli_approval_ui.py index 9b2e0bbb26..63e03b9ab9 100644 --- a/tests/cli/test_cli_approval_ui.py +++ b/tests/cli/test_cli_approval_ui.py @@ -2,22 +2,65 @@ import queue import threading import time from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch +import cli as cli_module from cli import HermesCLI +class _FakeBuffer: + def __init__(self, text="", cursor_position=None): + self.text = text + self.cursor_position = len(text) if cursor_position is None else cursor_position + + def reset(self, append_to_history=False): + self.text = "" + self.cursor_position = 0 + + def _make_cli_stub(): cli = HermesCLI.__new__(HermesCLI) cli._approval_state = None cli._approval_deadline = 0 cli._approval_lock = threading.Lock() + cli._sudo_state = None + cli._sudo_deadline = 0 + cli._modal_input_snapshot = None cli._invalidate = MagicMock() - cli._app = SimpleNamespace(invalidate=MagicMock()) + cli._app = SimpleNamespace(invalidate=MagicMock(), current_buffer=_FakeBuffer()) return cli class TestCliApprovalUi: + def test_sudo_prompt_restores_existing_draft_after_response(self): + cli = _make_cli_stub() + cli._app.current_buffer = _FakeBuffer("draft command", cursor_position=5) + result = {} + + def _run_callback(): + result["value"] = cli._sudo_password_callback() + + with patch.object(cli_module, "_cprint"): + thread = threading.Thread(target=_run_callback, daemon=True) + thread.start() + + deadline = time.time() + 2 + while cli._sudo_state is None and time.time() < deadline: + time.sleep(0.01) + + assert cli._sudo_state is not None + assert cli._app.current_buffer.text == "" + + cli._app.current_buffer.text = "secret" + cli._app.current_buffer.cursor_position = len("secret") + cli._sudo_state["response_queue"].put("secret") + + thread.join(timeout=2) + + assert result["value"] == "secret" + assert cli._app.current_buffer.text == "draft command" + assert cli._app.current_buffer.cursor_position == 5 + def test_approval_callback_includes_view_for_long_commands(self): cli = _make_cli_stub() command = "sudo dd if=/tmp/githubcli-keyring.gpg of=/usr/share/keyrings/githubcli-archive-keyring.gpg bs=4M status=progress" diff --git a/tests/gateway/test_bluebubbles.py b/tests/gateway/test_bluebubbles.py new file mode 100644 index 0000000000..939a69ff15 --- /dev/null +++ b/tests/gateway/test_bluebubbles.py @@ -0,0 +1,361 @@ +"""Tests for the BlueBubbles iMessage gateway adapter.""" +import pytest + +from gateway.config import Platform, PlatformConfig + + +def _make_adapter(monkeypatch, **extra): + monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234") + monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret") + from gateway.platforms.bluebubbles import BlueBubblesAdapter + + cfg = PlatformConfig( + enabled=True, + extra={ + "server_url": "http://localhost:1234", + "password": "secret", + **extra, + }, + ) + return BlueBubblesAdapter(cfg) + + +class TestBlueBubblesPlatformEnum: + def test_bluebubbles_enum_exists(self): + assert Platform.BLUEBUBBLES.value == "bluebubbles" + + +class TestBlueBubblesConfigLoading: + def test_apply_env_overrides_bluebubbles(self, monkeypatch): + monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234") + monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret") + monkeypatch.setenv("BLUEBUBBLES_WEBHOOK_PORT", "9999") + from gateway.config import GatewayConfig, _apply_env_overrides + + config = GatewayConfig() + _apply_env_overrides(config) + assert Platform.BLUEBUBBLES in config.platforms + bc = config.platforms[Platform.BLUEBUBBLES] + assert bc.enabled is True + assert bc.extra["server_url"] == "http://localhost:1234" + assert bc.extra["password"] == "secret" + assert bc.extra["webhook_port"] == 9999 + + def test_connected_platforms_includes_bluebubbles(self, monkeypatch): + monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234") + monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret") + from gateway.config import GatewayConfig, _apply_env_overrides + + config = GatewayConfig() + _apply_env_overrides(config) + assert Platform.BLUEBUBBLES in config.get_connected_platforms() + + def test_home_channel_set_from_env(self, monkeypatch): + monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234") + monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret") + monkeypatch.setenv("BLUEBUBBLES_HOME_CHANNEL", "user@example.com") + from gateway.config import GatewayConfig, _apply_env_overrides + + config = GatewayConfig() + _apply_env_overrides(config) + hc = config.platforms[Platform.BLUEBUBBLES].home_channel + assert hc is not None + assert hc.chat_id == "user@example.com" + + def test_not_connected_without_password(self, monkeypatch): + monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234") + monkeypatch.delenv("BLUEBUBBLES_PASSWORD", raising=False) + from gateway.config import GatewayConfig, _apply_env_overrides + + config = GatewayConfig() + _apply_env_overrides(config) + assert Platform.BLUEBUBBLES not in config.get_connected_platforms() + + +class TestBlueBubblesHelpers: + def test_check_requirements(self, monkeypatch): + monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234") + monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret") + from gateway.platforms.bluebubbles import check_bluebubbles_requirements + + assert check_bluebubbles_requirements() is True + + def test_format_message_strips_markdown(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + assert adapter.format_message("**Hello** `world`") == "Hello world" + + def test_strip_markdown_headers(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + assert adapter.format_message("## Heading\ntext") == "Heading\ntext" + + def test_strip_markdown_links(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + assert adapter.format_message("[click here](http://example.com)") == "click here" + + def test_init_normalizes_webhook_path(self, monkeypatch): + adapter = _make_adapter(monkeypatch, webhook_path="bluebubbles-webhook") + assert adapter.webhook_path == "/bluebubbles-webhook" + + def test_init_preserves_leading_slash(self, monkeypatch): + adapter = _make_adapter(monkeypatch, webhook_path="/my-hook") + assert adapter.webhook_path == "/my-hook" + + def test_server_url_normalized(self, monkeypatch): + adapter = _make_adapter(monkeypatch, server_url="http://localhost:1234/") + assert adapter.server_url == "http://localhost:1234" + + def test_server_url_adds_scheme(self, monkeypatch): + adapter = _make_adapter(monkeypatch, server_url="localhost:1234") + assert adapter.server_url == "http://localhost:1234" + + +class TestBlueBubblesWebhookParsing: + def test_webhook_prefers_chat_guid_over_message_guid(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + payload = { + "guid": "MESSAGE-GUID", + "chatGuid": "iMessage;-;user@example.com", + "chatIdentifier": "user@example.com", + } + record = adapter._extract_payload_record(payload) or {} + chat_guid = adapter._value( + record.get("chatGuid"), + payload.get("chatGuid"), + record.get("chat_guid"), + payload.get("chat_guid"), + payload.get("guid"), + ) + assert chat_guid == "iMessage;-;user@example.com" + + def test_webhook_can_fall_back_to_sender_when_chat_fields_missing(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + payload = { + "data": { + "guid": "MESSAGE-GUID", + "text": "hello", + "handle": {"address": "user@example.com"}, + "isFromMe": False, + } + } + record = adapter._extract_payload_record(payload) or {} + chat_guid = adapter._value( + record.get("chatGuid"), + payload.get("chatGuid"), + record.get("chat_guid"), + payload.get("chat_guid"), + payload.get("guid"), + ) + chat_identifier = adapter._value( + record.get("chatIdentifier"), + record.get("identifier"), + payload.get("chatIdentifier"), + payload.get("identifier"), + ) + sender = ( + adapter._value( + record.get("handle", {}).get("address") + if isinstance(record.get("handle"), dict) + else None, + record.get("sender"), + record.get("from"), + record.get("address"), + ) + or chat_identifier + or chat_guid + ) + if not (chat_guid or chat_identifier) and sender: + chat_identifier = sender + assert chat_identifier == "user@example.com" + + def test_extract_payload_record_accepts_list_data(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + payload = { + "type": "new-message", + "data": [ + { + "text": "hello", + "chatGuid": "iMessage;-;user@example.com", + "chatIdentifier": "user@example.com", + } + ], + } + record = adapter._extract_payload_record(payload) + assert record == payload["data"][0] + + def test_extract_payload_record_dict_data(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + payload = {"data": {"text": "hello", "chatGuid": "iMessage;-;+1234"}} + record = adapter._extract_payload_record(payload) + assert record["text"] == "hello" + + def test_extract_payload_record_fallback_to_message(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + payload = {"message": {"text": "hello"}} + record = adapter._extract_payload_record(payload) + assert record["text"] == "hello" + + +class TestBlueBubblesGuidResolution: + def test_raw_guid_returned_as_is(self, monkeypatch): + """If target already contains ';' it's a raw GUID — return unchanged.""" + adapter = _make_adapter(monkeypatch) + import asyncio + + result = asyncio.get_event_loop().run_until_complete( + adapter._resolve_chat_guid("iMessage;-;user@example.com") + ) + assert result == "iMessage;-;user@example.com" + + def test_empty_target_returns_none(self, monkeypatch): + adapter = _make_adapter(monkeypatch) + import asyncio + + result = asyncio.get_event_loop().run_until_complete( + adapter._resolve_chat_guid("") + ) + assert result is None + + +class TestBlueBubblesToolsetIntegration: + def test_toolset_exists(self): + from toolsets import TOOLSETS + + assert "hermes-bluebubbles" in TOOLSETS + + def test_toolset_in_gateway_composite(self): + from toolsets import TOOLSETS + + gateway = TOOLSETS["hermes-gateway"] + assert "hermes-bluebubbles" in gateway["includes"] + + +class TestBlueBubblesPromptHint: + def test_platform_hint_exists(self): + from agent.prompt_builder import PLATFORM_HINTS + + assert "bluebubbles" in PLATFORM_HINTS + hint = PLATFORM_HINTS["bluebubbles"] + assert "iMessage" in hint + assert "plain text" in hint + + +class TestBlueBubblesAttachmentDownload: + """Verify _download_attachment routes to the correct cache helper.""" + + def test_download_image_uses_image_cache(self, monkeypatch): + """Image MIME routes to cache_image_from_bytes.""" + adapter = _make_adapter(monkeypatch) + import asyncio + import httpx + + # Mock the HTTP client response + class MockResponse: + status_code = 200 + content = b"\x89PNG\r\n\x1a\n" + + def raise_for_status(self): + pass + + async def mock_get(*args, **kwargs): + return MockResponse() + + adapter.client = type("MockClient", (), {"get": mock_get})() + + cached_path = None + + def mock_cache_image(data, ext): + nonlocal cached_path + cached_path = f"/tmp/test_image{ext}" + return cached_path + + monkeypatch.setattr( + "gateway.platforms.bluebubbles.cache_image_from_bytes", + mock_cache_image, + ) + + att_meta = {"mimeType": "image/png", "transferName": "photo.png"} + result = asyncio.get_event_loop().run_until_complete( + adapter._download_attachment("att-guid-123", att_meta) + ) + assert result == "/tmp/test_image.png" + + def test_download_audio_uses_audio_cache(self, monkeypatch): + """Audio MIME routes to cache_audio_from_bytes.""" + adapter = _make_adapter(monkeypatch) + import asyncio + + class MockResponse: + status_code = 200 + content = b"fake-audio-data" + + def raise_for_status(self): + pass + + async def mock_get(*args, **kwargs): + return MockResponse() + + adapter.client = type("MockClient", (), {"get": mock_get})() + + cached_path = None + + def mock_cache_audio(data, ext): + nonlocal cached_path + cached_path = f"/tmp/test_audio{ext}" + return cached_path + + monkeypatch.setattr( + "gateway.platforms.bluebubbles.cache_audio_from_bytes", + mock_cache_audio, + ) + + att_meta = {"mimeType": "audio/mpeg", "transferName": "voice.mp3"} + result = asyncio.get_event_loop().run_until_complete( + adapter._download_attachment("att-guid-456", att_meta) + ) + assert result == "/tmp/test_audio.mp3" + + def test_download_document_uses_document_cache(self, monkeypatch): + """Non-image/audio MIME routes to cache_document_from_bytes.""" + adapter = _make_adapter(monkeypatch) + import asyncio + + class MockResponse: + status_code = 200 + content = b"fake-doc-data" + + def raise_for_status(self): + pass + + async def mock_get(*args, **kwargs): + return MockResponse() + + adapter.client = type("MockClient", (), {"get": mock_get})() + + cached_path = None + + def mock_cache_doc(data, filename): + nonlocal cached_path + cached_path = f"/tmp/{filename}" + return cached_path + + monkeypatch.setattr( + "gateway.platforms.bluebubbles.cache_document_from_bytes", + mock_cache_doc, + ) + + att_meta = {"mimeType": "application/pdf", "transferName": "report.pdf"} + result = asyncio.get_event_loop().run_until_complete( + adapter._download_attachment("att-guid-789", att_meta) + ) + assert result == "/tmp/report.pdf" + + def test_download_returns_none_without_client(self, monkeypatch): + """No client → returns None gracefully.""" + adapter = _make_adapter(monkeypatch) + adapter.client = None + import asyncio + + result = asyncio.get_event_loop().run_until_complete( + adapter._download_attachment("att-guid", {"mimeType": "image/png"}) + ) + assert result is None diff --git a/tests/gateway/test_discord_document_handling.py b/tests/gateway/test_discord_document_handling.py index 7f918d1c73..a22e0f0d66 100644 --- a/tests/gateway/test_discord_document_handling.py +++ b/tests/gateway/test_discord_document_handling.py @@ -209,14 +209,31 @@ class TestIncomingDocumentHandling: assert "[Content of readme.md]:" in event.text assert "# Title" in event.text + @pytest.mark.asyncio + async def test_log_content_injected(self, adapter): + """.log file under 100KB should be treated as text/plain and injected.""" + file_content = b"BLE trace line 1\nBLE trace line 2" + + with _mock_aiohttp_download(file_content): + msg = make_message( + attachments=[make_attachment(filename="btsnoop_hci.log", content_type="text/plain")], + content="please inspect this", + ) + await adapter._handle_message(msg) + + event = adapter.handle_message.call_args[0][0] + assert "[Content of btsnoop_hci.log]:" in event.text + assert "BLE trace line 1" in event.text + assert "please inspect this" in event.text + @pytest.mark.asyncio async def test_oversized_document_skipped(self, adapter): - """A document over 20MB should be skipped — media_urls stays empty.""" + """A document over 32MB should be skipped — media_urls stays empty.""" msg = make_message([ make_attachment( filename="huge.pdf", content_type="application/pdf", - size=25 * 1024 * 1024, + size=33 * 1024 * 1024, ) ]) await adapter._handle_message(msg) @@ -226,6 +243,24 @@ class TestIncomingDocumentHandling: # handler must still be called adapter.handle_message.assert_called_once() + @pytest.mark.asyncio + async def test_mid_sized_zip_under_32mb_is_cached(self, adapter): + """A 25MB .zip should be accepted now that Discord documents allow up to 32MB.""" + msg = make_message([ + make_attachment( + filename="bugreport.zip", + content_type="application/zip", + size=25 * 1024 * 1024, + ) + ]) + + with _mock_aiohttp_download(b"PK\x03\x04test"): + await adapter._handle_message(msg) + + event = adapter.handle_message.call_args[0][0] + assert len(event.media_urls) == 1 + assert event.media_types == ["application/zip"] + @pytest.mark.asyncio async def test_zip_document_cached(self, adapter): """A .zip file should be cached as a supported document.""" diff --git a/tests/gateway/test_gateway_inactivity_timeout.py b/tests/gateway/test_gateway_inactivity_timeout.py new file mode 100644 index 0000000000..598f33817c --- /dev/null +++ b/tests/gateway/test_gateway_inactivity_timeout.py @@ -0,0 +1,315 @@ +"""Tests for staged inactivity timeout in gateway agent runs. + +Tests cover: +- Warning fires once when inactivity reaches gateway_timeout_warning threshold +- Warning does not fire when gateway_timeout is 0 (unlimited) +- Warning fires only once per run, not on every poll +- Full timeout still fires at gateway_timeout threshold +- Warning respects HERMES_AGENT_TIMEOUT_WARNING env var +- Warning disabled when gateway_timeout_warning is 0 +""" + +import concurrent.futures +import os +import sys +import time +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + + +class FakeAgent: + """Mock agent with controllable activity summary for timeout tests.""" + + def __init__(self, idle_seconds=0.0, activity_desc="tool_call", + current_tool=None, api_call_count=5, max_iterations=90): + self._idle_seconds = idle_seconds + self._activity_desc = activity_desc + self._current_tool = current_tool + self._api_call_count = api_call_count + self._max_iterations = max_iterations + self._interrupted = False + self._interrupt_msg = None + + def get_activity_summary(self): + return { + "last_activity_ts": time.time() - self._idle_seconds, + "last_activity_desc": self._activity_desc, + "seconds_since_activity": self._idle_seconds, + "current_tool": self._current_tool, + "api_call_count": self._api_call_count, + "max_iterations": self._max_iterations, + } + + def interrupt(self, msg): + self._interrupted = True + self._interrupt_msg = msg + + def run_conversation(self, prompt): + return {"final_response": "Done", "messages": []} + + +class SlowFakeAgent(FakeAgent): + """Agent that runs for a while, then goes idle.""" + + def __init__(self, run_duration=0.5, idle_after=None, **kwargs): + super().__init__(**kwargs) + self._run_duration = run_duration + self._idle_after = idle_after + self._start_time = None + + def get_activity_summary(self): + summary = super().get_activity_summary() + if self._idle_after is not None and self._start_time: + elapsed = time.time() - self._start_time + if elapsed > self._idle_after: + idle_time = elapsed - self._idle_after + summary["seconds_since_activity"] = idle_time + summary["last_activity_desc"] = "api_call_streaming" + else: + summary["seconds_since_activity"] = 0.0 + return summary + + def run_conversation(self, prompt): + self._start_time = time.time() + time.sleep(self._run_duration) + return {"final_response": "Completed after work", "messages": []} + + +class TestStagedInactivityWarning: + """Test the staged inactivity warning before full timeout.""" + + def test_warning_fires_once_before_timeout(self): + """Warning fires when inactivity reaches warning threshold.""" + agent = SlowFakeAgent( + run_duration=10.0, + idle_after=0.1, + activity_desc="api_call_streaming", + ) + + _agent_timeout = 20.0 + _agent_warning = 5.0 + _POLL_INTERVAL = 0.1 + + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = pool.submit(agent.run_conversation, "test prompt") + _inactivity_timeout = False + _warning_fired = False + _warning_send_count = 0 + + while True: + done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL) + if done: + result = future.result() + break + _idle_secs = 0.0 + if hasattr(agent, "get_activity_summary"): + try: + _act = agent.get_activity_summary() + _idle_secs = _act.get("seconds_since_activity", 0.0) + except Exception: + pass + if (not _warning_fired and _agent_warning > 0 + and _idle_secs >= _agent_warning): + _warning_fired = True + _warning_send_count += 1 + if _idle_secs >= _agent_timeout: + _inactivity_timeout = True + break + + pool.shutdown(wait=False, cancel_futures=True) + + assert _warning_fired + assert _warning_send_count == 1 + assert not _inactivity_timeout + + def test_warning_disabled_when_zero(self): + """No warning fires when gateway_timeout_warning is 0.""" + agent = SlowFakeAgent( + run_duration=5.0, + idle_after=0.1, + ) + + _agent_timeout = 20.0 + _agent_warning = 0.0 + _POLL_INTERVAL = 0.1 + + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = pool.submit(agent.run_conversation, "test") + _warning_fired = False + + while True: + done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL) + if done: + future.result() + break + _idle_secs = 0.0 + if hasattr(agent, "get_activity_summary"): + try: + _act = agent.get_activity_summary() + _idle_secs = _act.get("seconds_since_activity", 0.0) + except Exception: + pass + if (not _warning_fired and _agent_warning > 0 + and _idle_secs >= _agent_warning): + _warning_fired = True + if _idle_secs >= _agent_timeout: + break + + pool.shutdown(wait=False, cancel_futures=True) + assert not _warning_fired + + def test_warning_fires_only_once(self): + """Warning fires exactly once even if agent remains idle.""" + agent = SlowFakeAgent( + run_duration=10.0, + idle_after=0.05, + ) + + _agent_timeout = 20.0 + _agent_warning = 0.2 + _POLL_INTERVAL = 0.05 + + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = pool.submit(agent.run_conversation, "test") + _warning_count = 0 + + while True: + done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL) + if done: + future.result() + break + _idle_secs = 0.0 + if hasattr(agent, "get_activity_summary"): + try: + _act = agent.get_activity_summary() + _idle_secs = _act.get("seconds_since_activity", 0.0) + except Exception: + pass + if (not _warning_count and _agent_warning > 0 + and _idle_secs >= _agent_warning): + _warning_count += 1 + if _idle_secs >= _agent_timeout: + break + + pool.shutdown(wait=False, cancel_futures=True) + assert _warning_count == 1 + + def test_full_timeout_still_fires_after_warning(self): + """Full timeout fires even after warning was sent.""" + agent = SlowFakeAgent( + run_duration=15.0, + idle_after=0.1, + activity_desc="waiting for provider response (streaming)", + ) + + _agent_timeout = 1.0 + _agent_warning = 0.3 + _POLL_INTERVAL = 0.05 + + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = pool.submit(agent.run_conversation, "test") + _inactivity_timeout = False + _warning_fired = False + + while True: + done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL) + if done: + future.result() + break + _idle_secs = 0.0 + if hasattr(agent, "get_activity_summary"): + try: + _act = agent.get_activity_summary() + _idle_secs = _act.get("seconds_since_activity", 0.0) + except Exception: + pass + if (not _warning_fired and _agent_warning > 0 + and _idle_secs >= _agent_warning): + _warning_fired = True + if _idle_secs >= _agent_timeout: + _inactivity_timeout = True + break + + pool.shutdown(wait=False, cancel_futures=True) + assert _warning_fired + assert _inactivity_timeout + + def test_warning_env_var_respected(self, monkeypatch): + """HERMES_AGENT_TIMEOUT_WARNING env var is parsed correctly.""" + monkeypatch.setenv("HERMES_AGENT_TIMEOUT_WARNING", "600") + _warning = float(os.getenv("HERMES_AGENT_TIMEOUT_WARNING", 900)) + assert _warning == 600.0 + + def test_warning_zero_means_disabled(self, monkeypatch): + """HERMES_AGENT_TIMEOUT_WARNING=0 disables the warning.""" + monkeypatch.setenv("HERMES_AGENT_TIMEOUT_WARNING", "0") + _raw = float(os.getenv("HERMES_AGENT_TIMEOUT_WARNING", 900)) + _warning = _raw if _raw > 0 else None + assert _warning is None + + def test_unlimited_timeout_no_warning(self): + """When timeout is unlimited (0), no warning fires either.""" + agent = SlowFakeAgent( + run_duration=0.5, + idle_after=0.0, + ) + + _agent_timeout = None + _agent_warning = 5.0 + _POLL_INTERVAL = 0.05 + + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = pool.submit(agent.run_conversation, "test") + + result = future.result(timeout=2.0) + pool.shutdown(wait=False) + + assert result["final_response"] == "Completed after work" + + +class TestWarningThresholdBelowTimeout: + """Test that warning threshold must be less than timeout threshold.""" + + def test_warning_at_half_timeout(self): + """Warning fires at half the timeout duration.""" + agent = SlowFakeAgent( + run_duration=10.0, + idle_after=0.1, + activity_desc="receiving stream response", + ) + + _agent_timeout = 2.0 + _agent_warning = 1.0 + _POLL_INTERVAL = 0.05 + + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = pool.submit(agent.run_conversation, "test") + _warning_fired = False + _timeout_fired = False + + while True: + done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL) + if done: + future.result() + break + _idle_secs = 0.0 + if hasattr(agent, "get_activity_summary"): + try: + _act = agent.get_activity_summary() + _idle_secs = _act.get("seconds_since_activity", 0.0) + except Exception: + pass + if (not _warning_fired and _agent_warning > 0 + and _idle_secs >= _agent_warning): + _warning_fired = True + if _idle_secs >= _agent_timeout: + _timeout_fired = True + break + + pool.shutdown(wait=False, cancel_futures=True) + assert _warning_fired + assert _timeout_fired diff --git a/tests/gateway/test_internal_event_bypass_pairing.py b/tests/gateway/test_internal_event_bypass_pairing.py new file mode 100644 index 0000000000..19ecd7059e --- /dev/null +++ b/tests/gateway/test_internal_event_bypass_pairing.py @@ -0,0 +1,226 @@ +"""Tests that internal synthetic events (e.g. background process completion) +bypass user authorization and do not trigger DM pairing. + +Regression test for the bug where ``_run_process_watcher`` with +``notify_on_complete=True`` injected a ``MessageEvent`` without ``user_id``, +causing ``_is_user_authorized`` to reject it and the gateway to send a +pairing code to the chat. +""" + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +import pytest + +from gateway.config import GatewayConfig, Platform +from gateway.platforms.base import MessageEvent +from gateway.run import GatewayRunner +from gateway.session import SessionSource + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class _FakeRegistry: + """Return pre-canned sessions, then None once exhausted.""" + + def __init__(self, sessions): + self._sessions = list(sessions) + + def get(self, session_id): + if self._sessions: + return self._sessions.pop(0) + return None + + +def _build_runner(monkeypatch, tmp_path) -> GatewayRunner: + """Create a GatewayRunner with notifications set to 'all'.""" + (tmp_path / "config.yaml").write_text( + "display:\n background_process_notifications: all\n", + encoding="utf-8", + ) + + import gateway.run as gateway_run + + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + + runner = GatewayRunner(GatewayConfig()) + adapter = SimpleNamespace(send=AsyncMock(), handle_message=AsyncMock()) + runner.adapters[Platform.DISCORD] = adapter + return runner + + +def _watcher_dict_with_notify(): + return { + "session_id": "proc_test_internal", + "check_interval": 0, + "session_key": "agent:main:discord:dm:123", + "platform": "discord", + "chat_id": "123", + "thread_id": "", + "notify_on_complete": True, + } + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_notify_on_complete_sets_internal_flag(monkeypatch, tmp_path): + """Synthetic completion event must have internal=True.""" + import tools.process_registry as pr_module + + sessions = [ + SimpleNamespace( + output_buffer="done\n", exited=True, exit_code=0, command="echo test" + ), + ] + monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions)) + + async def _instant_sleep(*_a, **_kw): + pass + monkeypatch.setattr(asyncio, "sleep", _instant_sleep) + + runner = _build_runner(monkeypatch, tmp_path) + adapter = runner.adapters[Platform.DISCORD] + + await runner._run_process_watcher(_watcher_dict_with_notify()) + + assert adapter.handle_message.await_count == 1 + event = adapter.handle_message.await_args.args[0] + assert isinstance(event, MessageEvent) + assert event.internal is True, "Synthetic completion event must be marked internal" + + +@pytest.mark.asyncio +async def test_internal_event_bypasses_authorization(monkeypatch, tmp_path): + """An internal event should skip _is_user_authorized entirely.""" + import gateway.run as gateway_run + + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + (tmp_path / "config.yaml").write_text("", encoding="utf-8") + + runner = GatewayRunner(GatewayConfig()) + + # Create an internal event with no user_id (simulates the bug scenario) + source = SessionSource( + platform=Platform.DISCORD, + chat_id="123", + chat_type="dm", + ) + event = MessageEvent( + text="[SYSTEM: Background process completed]", + source=source, + internal=True, + ) + + # Track if _is_user_authorized is called + auth_called = False + original_auth = GatewayRunner._is_user_authorized + + def tracking_auth(self, src): + nonlocal auth_called + auth_called = True + return original_auth(self, src) + + monkeypatch.setattr(GatewayRunner, "_is_user_authorized", tracking_auth) + + # _handle_message will proceed past auth check and eventually fail on + # downstream logic. We just need to verify auth is skipped. + try: + await runner._handle_message(event) + except Exception: + pass # Expected — downstream code needs more setup + + assert not auth_called, ( + "_is_user_authorized should NOT be called for internal events" + ) + + +@pytest.mark.asyncio +async def test_internal_event_does_not_trigger_pairing(monkeypatch, tmp_path): + """An internal event with no user_id must not generate a pairing code.""" + import gateway.run as gateway_run + + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + (tmp_path / "config.yaml").write_text("", encoding="utf-8") + + runner = GatewayRunner(GatewayConfig()) + # Add adapter so pairing would have somewhere to send + adapter = SimpleNamespace(send=AsyncMock()) + runner.adapters[Platform.DISCORD] = adapter + + source = SessionSource( + platform=Platform.DISCORD, + chat_id="123", + chat_type="dm", # DM would normally trigger pairing + ) + event = MessageEvent( + text="[SYSTEM: Background process completed]", + source=source, + internal=True, + ) + + # Track pairing code generation + generate_called = False + original_generate = runner.pairing_store.generate_code + + def tracking_generate(*args, **kwargs): + nonlocal generate_called + generate_called = True + return original_generate(*args, **kwargs) + + runner.pairing_store.generate_code = tracking_generate + + try: + await runner._handle_message(event) + except Exception: + pass # Expected — downstream code needs more setup + + assert not generate_called, ( + "Pairing code should NOT be generated for internal events" + ) + + +@pytest.mark.asyncio +async def test_non_internal_event_without_user_triggers_pairing(monkeypatch, tmp_path): + """Verify the normal (non-internal) path still triggers pairing for unknown users.""" + import gateway.run as gateway_run + + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + (tmp_path / "config.yaml").write_text("", encoding="utf-8") + + # Clear env vars that could let all users through (loaded by + # module-level dotenv in gateway/run.py from the real ~/.hermes/.env). + monkeypatch.delenv("DISCORD_ALLOW_ALL_USERS", raising=False) + monkeypatch.delenv("DISCORD_ALLOWED_USERS", raising=False) + monkeypatch.delenv("GATEWAY_ALLOW_ALL_USERS", raising=False) + monkeypatch.delenv("GATEWAY_ALLOWED_USERS", raising=False) + + runner = GatewayRunner(GatewayConfig()) + adapter = SimpleNamespace(send=AsyncMock()) + runner.adapters[Platform.DISCORD] = adapter + + source = SessionSource( + platform=Platform.DISCORD, + chat_id="123", + chat_type="dm", + user_id="unknown_user_999", + ) + # Normal event (not internal) + event = MessageEvent( + text="hello", + source=source, + internal=False, + ) + + result = await runner._handle_message(event) + + # Should return None (unauthorized) and send pairing message + assert result is None + assert adapter.send.await_count == 1 + sent_text = adapter.send.await_args.args[1] + assert "don't recognize you" in sent_text diff --git a/tests/gateway/test_signal.py b/tests/gateway/test_signal.py index b2830e1fcd..ae985300d1 100644 --- a/tests/gateway/test_signal.py +++ b/tests/gateway/test_signal.py @@ -707,3 +707,66 @@ class TestSignalSendDocumentViaHelper: assert result.success is False assert "/nonexistent.pdf" in result.error + + +# --------------------------------------------------------------------------- +# send() returns message_id from timestamp (#4647) +# --------------------------------------------------------------------------- + +class TestSignalSendReturnsMessageId: + """Signal send() must return a timestamp-based message_id so the stream + consumer can follow its edit→fallback path correctly.""" + + @pytest.mark.asyncio + async def test_send_returns_timestamp_as_message_id(self, monkeypatch): + adapter = _make_signal_adapter(monkeypatch) + mock_rpc, _ = _stub_rpc({"timestamp": 1712345678000}) + adapter._rpc = mock_rpc + adapter._stop_typing_indicator = AsyncMock() + + result = await adapter.send(chat_id="+155****4567", content="hello") + + assert result.success is True + assert result.message_id == "1712345678000" + + @pytest.mark.asyncio + async def test_send_returns_none_message_id_when_no_timestamp(self, monkeypatch): + adapter = _make_signal_adapter(monkeypatch) + mock_rpc, _ = _stub_rpc({}) # No timestamp key + adapter._rpc = mock_rpc + adapter._stop_typing_indicator = AsyncMock() + + result = await adapter.send(chat_id="+155****4567", content="hello") + + assert result.success is True + assert result.message_id is None + + @pytest.mark.asyncio + async def test_send_returns_none_message_id_for_non_dict(self, monkeypatch): + adapter = _make_signal_adapter(monkeypatch) + mock_rpc, _ = _stub_rpc("ok") # Non-dict result + adapter._rpc = mock_rpc + adapter._stop_typing_indicator = AsyncMock() + + result = await adapter.send(chat_id="+155****4567", content="hello") + + assert result.success is True + assert result.message_id is None + + +# --------------------------------------------------------------------------- +# stop_typing() delegates to _stop_typing_indicator (#4647) +# --------------------------------------------------------------------------- + +class TestSignalStopTyping: + """Signal must expose a public stop_typing() so base adapter's + _keep_typing finally block can clean up platform-level typing tasks.""" + + @pytest.mark.asyncio + async def test_stop_typing_calls_private_method(self, monkeypatch): + adapter = _make_signal_adapter(monkeypatch) + adapter._stop_typing_indicator = AsyncMock() + + await adapter.stop_typing("+155****4567") + + adapter._stop_typing_indicator.assert_awaited_once_with("+155****4567") diff --git a/tests/gateway/test_slack.py b/tests/gateway/test_slack.py index 89b4471834..67c7cce1dc 100644 --- a/tests/gateway/test_slack.py +++ b/tests/gateway/test_slack.py @@ -96,7 +96,7 @@ class TestAppMentionHandler: """Verify that the app_mention event handler is registered.""" def test_app_mention_registered_on_connect(self): - """connect() should register both 'message' and 'app_mention' handlers.""" + """connect() should register message + assistant lifecycle handlers.""" config = PlatformConfig(enabled=True, token="xoxb-fake") adapter = SlackAdapter(config) @@ -145,6 +145,8 @@ class TestAppMentionHandler: assert "message" in registered_events assert "app_mention" in registered_events + assert "assistant_thread_started" in registered_events + assert "assistant_thread_context_changed" in registered_events assert "/hermes" in registered_commands @@ -840,6 +842,114 @@ class TestThreadReplyHandling: adapter.handle_message.assert_not_called() +# --------------------------------------------------------------------------- +# TestAssistantThreadLifecycle +# --------------------------------------------------------------------------- + + +class TestAssistantThreadLifecycle: + """Slack Assistant lifecycle events should seed session/user context.""" + + @pytest.fixture() + def mock_session_store(self): + store = MagicMock() + store._entries = {} + store._ensure_loaded = MagicMock() + store.config = MagicMock() + store.config.group_sessions_per_user = True + store.get_or_create_session = MagicMock() + return store + + @pytest.fixture() + def assistant_adapter(self, mock_session_store): + config = PlatformConfig(enabled=True, token="***") + a = SlackAdapter(config) + a._app = MagicMock() + a._app.client = AsyncMock() + a._bot_user_id = "U_BOT" + a._team_bot_user_ids = {"T_TEAM": "U_BOT"} + a._running = True + a.handle_message = AsyncMock() + a.set_session_store(mock_session_store) + return a + + @pytest.mark.asyncio + async def test_lifecycle_event_seeds_session_store(self, assistant_adapter, mock_session_store): + event = { + "type": "assistant_thread_started", + "team_id": "T_TEAM", + "assistant_thread": { + "channel_id": "D123", + "thread_ts": "171.000", + "user_id": "U_USER", + "context": {"channel_id": "C_ORIGIN"}, + }, + } + + await assistant_adapter._handle_assistant_thread_lifecycle_event(event) + + assert assistant_adapter._assistant_threads[("D123", "171.000")]["user_id"] == "U_USER" + mock_session_store.get_or_create_session.assert_called_once() + source = mock_session_store.get_or_create_session.call_args[0][0] + assert source.chat_id == "D123" + assert source.chat_type == "dm" + assert source.user_id == "U_USER" + assert source.thread_id == "171.000" + assert source.chat_topic == "C_ORIGIN" + + @pytest.mark.asyncio + async def test_message_uses_cached_assistant_thread_identity(self, assistant_adapter): + assistant_adapter._assistant_threads[("D123", "171.000")] = { + "channel_id": "D123", + "thread_ts": "171.000", + "user_id": "U_USER", + "team_id": "T_TEAM", + } + assistant_adapter._app.client.users_info = AsyncMock(return_value={ + "user": {"profile": {"display_name": "Tyler"}} + }) + assistant_adapter._app.client.reactions_add = AsyncMock() + assistant_adapter._app.client.reactions_remove = AsyncMock() + + event = { + "text": "hello from assistant dm", + "channel": "D123", + "channel_type": "im", + "thread_ts": "171.000", + "ts": "171.111", + "team": "T_TEAM", + } + + await assistant_adapter._handle_slack_message(event) + + msg_event = assistant_adapter.handle_message.call_args[0][0] + assert msg_event.source.user_id == "U_USER" + assert msg_event.source.thread_id == "171.000" + assert msg_event.source.user_name == "Tyler" + + def test_assistant_threads_cache_eviction(self, assistant_adapter): + """Cache should evict oldest entries when exceeding the size limit.""" + assistant_adapter._ASSISTANT_THREADS_MAX = 10 + # Fill to the limit + for i in range(10): + assistant_adapter._cache_assistant_thread_metadata({ + "channel_id": f"D{i}", + "thread_ts": f"{i}.000", + "user_id": f"U{i}", + }) + assert len(assistant_adapter._assistant_threads) == 10 + + # Adding one more should trigger eviction (down to max // 2 = 5) + assistant_adapter._cache_assistant_thread_metadata({ + "channel_id": "D999", + "thread_ts": "999.000", + "user_id": "U999", + }) + assert len(assistant_adapter._assistant_threads) <= 10 + # The newest entry must survive eviction + assert ("D999", "999.000") in assistant_adapter._assistant_threads + + # --------------------------------------------------------------------------- # TestUserNameResolution # --------------------------------------------------------------------------- diff --git a/tests/gateway/test_stream_consumer.py b/tests/gateway/test_stream_consumer.py index ddc88fc2fc..d5a20331b6 100644 --- a/tests/gateway/test_stream_consumer.py +++ b/tests/gateway/test_stream_consumer.py @@ -383,6 +383,60 @@ class TestSegmentBreakOnToolBoundary: sent_texts = [call[1]["content"] for call in adapter.send.call_args_list] assert sent_texts == ["Hello ▉", "Next segment"] + @pytest.mark.asyncio + async def test_no_message_id_enters_fallback_mode(self): + """Platform returns success but no message_id (Signal) — must not + re-send on every delta. Should enter fallback mode and send only + the continuation at finish.""" + adapter = MagicMock() + # First send succeeds but returns no message_id (Signal behavior) + send_result_no_id = SimpleNamespace(success=True, message_id=None) + # Fallback final send succeeds + send_result_final = SimpleNamespace(success=True, message_id="msg_final") + adapter.send = AsyncMock(side_effect=[send_result_no_id, send_result_final]) + adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True)) + adapter.MAX_MESSAGE_LENGTH = 4096 + + config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5) + consumer = GatewayStreamConsumer(adapter, "chat_123", config) + + consumer.on_delta("Hello") + task = asyncio.create_task(consumer.run()) + await asyncio.sleep(0.08) + consumer.on_delta(" world, this is a longer response.") + await asyncio.sleep(0.08) + consumer.finish() + await task + + # Should send exactly 2 messages: initial chunk + fallback continuation + # NOT one message per delta + assert adapter.send.call_count == 2 + assert consumer.already_sent + # edit_message should NOT have been called (no valid message_id to edit) + adapter.edit_message.assert_not_called() + + @pytest.mark.asyncio + async def test_no_message_id_single_delta_marks_already_sent(self): + """When the entire response fits in one delta and platform returns no + message_id, already_sent must still be True to prevent the gateway + from re-sending the full response.""" + adapter = MagicMock() + send_result = SimpleNamespace(success=True, message_id=None) + adapter.send = AsyncMock(return_value=send_result) + adapter.MAX_MESSAGE_LENGTH = 4096 + + config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5) + consumer = GatewayStreamConsumer(adapter, "chat_123", config) + + consumer.on_delta("Short response.") + consumer.finish() + + await consumer.run() + + assert consumer.already_sent + # Only one send call (the initial message) + assert adapter.send.call_count == 1 + @pytest.mark.asyncio async def test_fallback_final_splits_long_continuation_without_dropping_text(self): """Long continuation tails should be chunked when fallback final-send runs.""" diff --git a/tests/hermes_cli/test_model_switch_variant_tags.py b/tests/hermes_cli/test_model_switch_variant_tags.py new file mode 100644 index 0000000000..eebb5dc139 --- /dev/null +++ b/tests/hermes_cli/test_model_switch_variant_tags.py @@ -0,0 +1,70 @@ +"""Tests for OpenRouter variant tag preservation in model switching. + +Regression test for GitHub PR #6088 / Discord report: OpenRouter model IDs +with variant suffixes like ``:free``, ``:extended``, ``:fast`` were being +mangled by the colon-to-slash conversion in model_switch.py Step c. + +The fix: Step c now skips colon→slash conversion when the model name already +contains a forward slash (i.e. is already in ``vendor/model`` format), since +the colon is a variant tag, not a vendor separator. +""" +import pytest +from unittest.mock import patch + +from hermes_cli.model_switch import switch_model + + +# Shared mock context — skip network calls, credential resolution, catalog lookups +_MOCK_VALIDATION = {"accepted": True, "persist": True, "recognized": True, "message": None} + + +def _run_switch(raw_input: str, current_provider: str = "openrouter") -> str: + """Run switch_model with mocked dependencies, return the resolved model name.""" + with patch("hermes_cli.model_switch.resolve_alias", return_value=None), \ + patch("hermes_cli.model_switch.list_provider_models", return_value=[]), \ + patch("hermes_cli.runtime_provider.resolve_runtime_provider", + return_value={"api_key": "test", "base_url": "", "api_mode": "chat_completions"}), \ + patch("hermes_cli.models.validate_requested_model", return_value=_MOCK_VALIDATION), \ + patch("hermes_cli.model_switch.get_model_info", return_value=None), \ + patch("hermes_cli.model_switch.get_model_capabilities", return_value=None), \ + patch("hermes_cli.models.detect_provider_for_model", return_value=None): + result = switch_model( + raw_input=raw_input, + current_provider=current_provider, + current_model="anthropic/claude-sonnet-4.6", + ) + assert result.success, f"switch_model failed: {result.error_message}" + return result.new_model + + +class TestVariantTagPreservation: + """OpenRouter variant tags (:free, :extended, :fast) must survive model switching.""" + + @pytest.mark.parametrize("model,expected", [ + ("nvidia/nemotron-3-super-120b-a12b:free", "nvidia/nemotron-3-super-120b-a12b:free"), + ("anthropic/claude-sonnet-4.6:extended", "anthropic/claude-sonnet-4.6:extended"), + ("meta-llama/llama-4-maverick:fast", "meta-llama/llama-4-maverick:fast"), + ]) + def test_slash_format_preserves_variant_tag(self, model, expected): + """Models already in vendor/model:tag format must not have their tag mangled.""" + assert _run_switch(model) == expected + + def test_legacy_colon_format_converts_to_slash(self): + """Legacy vendor:model (no slash) should still be converted to vendor/model.""" + result = _run_switch("nvidia:nemotron-3-super-120b-a12b") + assert result == "nvidia/nemotron-3-super-120b-a12b" + + def test_legacy_colon_format_with_tag_converts_first_colon_only(self): + """vendor:model:free (no slash) → vendor/model:free — first colon becomes slash.""" + result = _run_switch("nvidia:nemotron-3-super-120b-a12b:free") + assert result == "nvidia/nemotron-3-super-120b-a12b:free" + + def test_bare_model_name_unaffected(self): + """Bare model names without colons or slashes should work normally.""" + result = _run_switch("claude-sonnet-4.6") + assert result == "anthropic/claude-sonnet-4.6" + + def test_already_correct_slug_no_tag(self): + """Standard vendor/model slugs without tags pass through unchanged.""" + result = _run_switch("anthropic/claude-sonnet-4.6") + assert result == "anthropic/claude-sonnet-4.6" diff --git a/tests/plugins/memory/test_hindsight_provider.py b/tests/plugins/memory/test_hindsight_provider.py new file mode 100644 index 0000000000..5548a29ad4 --- /dev/null +++ b/tests/plugins/memory/test_hindsight_provider.py @@ -0,0 +1,598 @@ +"""Tests for the Hindsight memory provider plugin. + +Tests cover config loading, tool handlers (tags, max_tokens, types), +prefetch (auto_recall, preamble, query truncation), sync_turn (auto_retain, +turn counting, tags), and schema completeness. +""" + +import json +import threading +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from plugins.memory.hindsight import ( + HindsightMemoryProvider, + RECALL_SCHEMA, + REFLECT_SCHEMA, + RETAIN_SCHEMA, + _load_config, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _clean_env(monkeypatch): + """Ensure no stale env vars leak between tests.""" + for key in ( + "HINDSIGHT_API_KEY", "HINDSIGHT_API_URL", "HINDSIGHT_BANK_ID", + "HINDSIGHT_BUDGET", "HINDSIGHT_MODE", "HINDSIGHT_LLM_API_KEY", + ): + monkeypatch.delenv(key, raising=False) + + +def _make_mock_client(): + """Create a mock Hindsight client with async methods.""" + client = MagicMock() + client.aretain = AsyncMock() + client.arecall = AsyncMock( + return_value=SimpleNamespace( + results=[ + SimpleNamespace(text="Memory 1"), + SimpleNamespace(text="Memory 2"), + ] + ) + ) + client.areflect = AsyncMock( + return_value=SimpleNamespace(text="Synthesized answer") + ) + client.aretain_batch = AsyncMock() + client.aclose = AsyncMock() + return client + + +@pytest.fixture() +def provider(tmp_path, monkeypatch): + """Create an initialized HindsightMemoryProvider with a mock client.""" + config = { + "mode": "cloud", + "apiKey": "test-key", + "api_url": "http://localhost:9999", + "bank_id": "test-bank", + "budget": "mid", + "memory_mode": "hybrid", + } + config_path = tmp_path / "hindsight" / "config.json" + config_path.parent.mkdir(parents=True, exist_ok=True) + config_path.write_text(json.dumps(config)) + + monkeypatch.setattr( + "plugins.memory.hindsight.get_hermes_home", lambda: tmp_path + ) + + p = HindsightMemoryProvider() + p.initialize(session_id="test-session", hermes_home=str(tmp_path), platform="cli") + p._client = _make_mock_client() + return p + + +@pytest.fixture() +def provider_with_config(tmp_path, monkeypatch): + """Create a provider factory that accepts custom config overrides.""" + def _make(**overrides): + config = { + "mode": "cloud", + "apiKey": "test-key", + "api_url": "http://localhost:9999", + "bank_id": "test-bank", + "budget": "mid", + "memory_mode": "hybrid", + } + config.update(overrides) + config_path = tmp_path / "hindsight" / "config.json" + config_path.parent.mkdir(parents=True, exist_ok=True) + config_path.write_text(json.dumps(config)) + + monkeypatch.setattr( + "plugins.memory.hindsight.get_hermes_home", lambda: tmp_path + ) + + p = HindsightMemoryProvider() + p.initialize(session_id="test-session", hermes_home=str(tmp_path), platform="cli") + p._client = _make_mock_client() + return p + return _make + + +# --------------------------------------------------------------------------- +# Schema tests +# --------------------------------------------------------------------------- + + +class TestSchemas: + def test_retain_schema_has_content(self): + assert RETAIN_SCHEMA["name"] == "hindsight_retain" + assert "content" in RETAIN_SCHEMA["parameters"]["properties"] + assert "content" in RETAIN_SCHEMA["parameters"]["required"] + + def test_recall_schema_has_query(self): + assert RECALL_SCHEMA["name"] == "hindsight_recall" + assert "query" in RECALL_SCHEMA["parameters"]["properties"] + assert "query" in RECALL_SCHEMA["parameters"]["required"] + + def test_reflect_schema_has_query(self): + assert REFLECT_SCHEMA["name"] == "hindsight_reflect" + assert "query" in REFLECT_SCHEMA["parameters"]["properties"] + + def test_get_tool_schemas_returns_three(self, provider): + schemas = provider.get_tool_schemas() + assert len(schemas) == 3 + names = {s["name"] for s in schemas} + assert names == {"hindsight_retain", "hindsight_recall", "hindsight_reflect"} + + def test_context_mode_returns_no_tools(self, provider_with_config): + p = provider_with_config(memory_mode="context") + assert p.get_tool_schemas() == [] + + +# --------------------------------------------------------------------------- +# Config tests +# --------------------------------------------------------------------------- + + +class TestConfig: + def test_default_values(self, provider): + assert provider._auto_retain is True + assert provider._auto_recall is True + assert provider._retain_every_n_turns == 1 + assert provider._recall_max_tokens == 4096 + assert provider._recall_max_input_chars == 800 + assert provider._tags is None + assert provider._recall_tags is None + assert provider._bank_mission == "" + assert provider._bank_retain_mission is None + assert provider._retain_context == "conversation between Hermes Agent and the User" + + def test_custom_config_values(self, provider_with_config): + p = provider_with_config( + tags=["tag1", "tag2"], + recall_tags=["recall-tag"], + recall_tags_match="all", + auto_retain=False, + auto_recall=False, + retain_every_n_turns=3, + retain_context="custom-ctx", + bank_retain_mission="Extract key facts", + recall_max_tokens=2048, + recall_types=["world", "experience"], + recall_prompt_preamble="Custom preamble:", + recall_max_input_chars=500, + bank_mission="Test agent mission", + ) + assert p._tags == ["tag1", "tag2"] + assert p._recall_tags == ["recall-tag"] + assert p._recall_tags_match == "all" + assert p._auto_retain is False + assert p._auto_recall is False + assert p._retain_every_n_turns == 3 + assert p._retain_context == "custom-ctx" + assert p._bank_retain_mission == "Extract key facts" + assert p._recall_max_tokens == 2048 + assert p._recall_types == ["world", "experience"] + assert p._recall_prompt_preamble == "Custom preamble:" + assert p._recall_max_input_chars == 500 + assert p._bank_mission == "Test agent mission" + + def test_config_from_env_fallback(self, tmp_path, monkeypatch): + """When no config file exists, falls back to env vars.""" + monkeypatch.setattr( + "plugins.memory.hindsight.get_hermes_home", + lambda: tmp_path / "nonexistent", + ) + monkeypatch.setenv("HINDSIGHT_MODE", "cloud") + monkeypatch.setenv("HINDSIGHT_API_KEY", "env-key") + monkeypatch.setenv("HINDSIGHT_BANK_ID", "env-bank") + monkeypatch.setenv("HINDSIGHT_BUDGET", "high") + + cfg = _load_config() + assert cfg["apiKey"] == "env-key" + assert cfg["banks"]["hermes"]["bankId"] == "env-bank" + assert cfg["banks"]["hermes"]["budget"] == "high" + + +# --------------------------------------------------------------------------- +# Tool handler tests +# --------------------------------------------------------------------------- + + +class TestToolHandlers: + def test_retain_success(self, provider): + result = json.loads(provider.handle_tool_call( + "hindsight_retain", {"content": "user likes dark mode"} + )) + assert result["result"] == "Memory stored successfully." + provider._client.aretain.assert_called_once() + call_kwargs = provider._client.aretain.call_args.kwargs + assert call_kwargs["bank_id"] == "test-bank" + assert call_kwargs["content"] == "user likes dark mode" + + def test_retain_with_tags(self, provider_with_config): + p = provider_with_config(tags=["pref", "ui"]) + p.handle_tool_call("hindsight_retain", {"content": "likes dark mode"}) + call_kwargs = p._client.aretain.call_args.kwargs + assert call_kwargs["tags"] == ["pref", "ui"] + + def test_retain_without_tags(self, provider): + provider.handle_tool_call("hindsight_retain", {"content": "hello"}) + call_kwargs = provider._client.aretain.call_args.kwargs + assert "tags" not in call_kwargs + + def test_retain_missing_content(self, provider): + result = json.loads(provider.handle_tool_call( + "hindsight_retain", {} + )) + assert "error" in result + + def test_recall_success(self, provider): + result = json.loads(provider.handle_tool_call( + "hindsight_recall", {"query": "dark mode"} + )) + assert "Memory 1" in result["result"] + assert "Memory 2" in result["result"] + + def test_recall_passes_max_tokens(self, provider_with_config): + p = provider_with_config(recall_max_tokens=2048) + p.handle_tool_call("hindsight_recall", {"query": "test"}) + call_kwargs = p._client.arecall.call_args.kwargs + assert call_kwargs["max_tokens"] == 2048 + + def test_recall_passes_tags(self, provider_with_config): + p = provider_with_config(recall_tags=["tag1"], recall_tags_match="all") + p.handle_tool_call("hindsight_recall", {"query": "test"}) + call_kwargs = p._client.arecall.call_args.kwargs + assert call_kwargs["tags"] == ["tag1"] + assert call_kwargs["tags_match"] == "all" + + def test_recall_passes_types(self, provider_with_config): + p = provider_with_config(recall_types=["world", "experience"]) + p.handle_tool_call("hindsight_recall", {"query": "test"}) + call_kwargs = p._client.arecall.call_args.kwargs + assert call_kwargs["types"] == ["world", "experience"] + + def test_recall_no_results(self, provider): + provider._client.arecall.return_value = SimpleNamespace(results=[]) + result = json.loads(provider.handle_tool_call( + "hindsight_recall", {"query": "test"} + )) + assert result["result"] == "No relevant memories found." + + def test_recall_missing_query(self, provider): + result = json.loads(provider.handle_tool_call( + "hindsight_recall", {} + )) + assert "error" in result + + def test_reflect_success(self, provider): + result = json.loads(provider.handle_tool_call( + "hindsight_reflect", {"query": "summarize"} + )) + assert result["result"] == "Synthesized answer" + + def test_reflect_missing_query(self, provider): + result = json.loads(provider.handle_tool_call( + "hindsight_reflect", {} + )) + assert "error" in result + + def test_unknown_tool(self, provider): + result = json.loads(provider.handle_tool_call( + "hindsight_unknown", {} + )) + assert "error" in result + + def test_retain_error_handling(self, provider): + provider._client.aretain.side_effect = RuntimeError("connection failed") + result = json.loads(provider.handle_tool_call( + "hindsight_retain", {"content": "test"} + )) + assert "error" in result + assert "connection failed" in result["error"] + + def test_recall_error_handling(self, provider): + provider._client.arecall.side_effect = RuntimeError("timeout") + result = json.loads(provider.handle_tool_call( + "hindsight_recall", {"query": "test"} + )) + assert "error" in result + + +# --------------------------------------------------------------------------- +# Prefetch tests +# --------------------------------------------------------------------------- + + +class TestPrefetch: + def test_prefetch_returns_empty_when_no_result(self, provider): + assert provider.prefetch("test") == "" + + def test_prefetch_default_preamble(self, provider): + provider._prefetch_result = "- some memory" + result = provider.prefetch("test") + assert "Hindsight Memory" in result + assert "- some memory" in result + + def test_prefetch_custom_preamble(self, provider_with_config): + p = provider_with_config(recall_prompt_preamble="Custom header:") + p._prefetch_result = "- memory line" + result = p.prefetch("test") + assert result.startswith("Custom header:") + assert "- memory line" in result + + def test_queue_prefetch_skipped_in_tools_mode(self, provider_with_config): + p = provider_with_config(memory_mode="tools") + p.queue_prefetch("test") + # Should not start a thread + assert p._prefetch_thread is None + + def test_queue_prefetch_skipped_when_auto_recall_off(self, provider_with_config): + p = provider_with_config(auto_recall=False) + p.queue_prefetch("test") + assert p._prefetch_thread is None + + def test_queue_prefetch_truncates_query(self, provider_with_config): + p = provider_with_config(recall_max_input_chars=10) + # Mock _run_sync to capture the query + original_query = None + + def _capture_recall(**kwargs): + nonlocal original_query + original_query = kwargs.get("query", "") + return SimpleNamespace(results=[]) + + p._client.arecall = AsyncMock(side_effect=_capture_recall) + + long_query = "a" * 100 + p.queue_prefetch(long_query) + if p._prefetch_thread: + p._prefetch_thread.join(timeout=5.0) + + # The query passed to arecall should be truncated + if original_query is not None: + assert len(original_query) <= 10 + + def test_queue_prefetch_passes_recall_params(self, provider_with_config): + p = provider_with_config( + recall_tags=["t1"], + recall_tags_match="all", + recall_max_tokens=1024, + recall_types=["world"], + ) + p.queue_prefetch("test query") + if p._prefetch_thread: + p._prefetch_thread.join(timeout=5.0) + + call_kwargs = p._client.arecall.call_args.kwargs + assert call_kwargs["max_tokens"] == 1024 + assert call_kwargs["tags"] == ["t1"] + assert call_kwargs["tags_match"] == "all" + assert call_kwargs["types"] == ["world"] + + +# --------------------------------------------------------------------------- +# sync_turn tests +# --------------------------------------------------------------------------- + + +class TestSyncTurn: + def _get_retain_kwargs(self, provider): + """Helper to get the kwargs from the aretain_batch call.""" + return provider._client.aretain_batch.call_args.kwargs + + def _get_retain_content(self, provider): + """Helper to get the raw content string from the first item.""" + kwargs = self._get_retain_kwargs(provider) + return kwargs["items"][0]["content"] + + def _get_retain_messages(self, provider): + """Helper to parse the first turn's messages from retained content. + + Content is a JSON array of turns: [[msgs...], [msgs...], ...] + For single-turn tests, returns the first turn's messages. + """ + content = self._get_retain_content(provider) + turns = json.loads(content) + return turns[0] if len(turns) == 1 else turns + + def test_sync_turn_retains(self, provider): + provider.sync_turn("hello", "hi there") + if provider._sync_thread: + provider._sync_thread.join(timeout=5.0) + provider._client.aretain_batch.assert_called_once() + messages = self._get_retain_messages(provider) + assert len(messages) == 2 + assert messages[0]["role"] == "user" + assert messages[0]["content"] == "hello" + assert "timestamp" in messages[0] + assert messages[1]["role"] == "assistant" + assert messages[1]["content"] == "hi there" + assert "timestamp" in messages[1] + + def test_sync_turn_skipped_when_auto_retain_off(self, provider_with_config): + p = provider_with_config(auto_retain=False) + p.sync_turn("hello", "hi") + assert p._sync_thread is None + p._client.aretain_batch.assert_not_called() + + def test_sync_turn_with_tags(self, provider_with_config): + p = provider_with_config(tags=["conv", "session1"]) + p.sync_turn("hello", "hi") + if p._sync_thread: + p._sync_thread.join(timeout=5.0) + item = p._client.aretain_batch.call_args.kwargs["items"][0] + assert item["tags"] == ["conv", "session1"] + + def test_sync_turn_uses_aretain_batch(self, provider): + """sync_turn should use aretain_batch with retain_async.""" + provider.sync_turn("hello", "hi") + if provider._sync_thread: + provider._sync_thread.join(timeout=5.0) + provider._client.aretain_batch.assert_called_once() + call_kwargs = provider._client.aretain_batch.call_args.kwargs + assert call_kwargs["document_id"] == "test-session" + assert call_kwargs["retain_async"] is True + assert len(call_kwargs["items"]) == 1 + assert call_kwargs["items"][0]["context"] == "conversation between Hermes Agent and the User" + + def test_sync_turn_custom_context(self, provider_with_config): + p = provider_with_config(retain_context="my-agent") + p.sync_turn("hello", "hi") + if p._sync_thread: + p._sync_thread.join(timeout=5.0) + item = p._client.aretain_batch.call_args.kwargs["items"][0] + assert item["context"] == "my-agent" + + def test_sync_turn_every_n_turns(self, provider_with_config): + """With retain_every_n_turns=3, only retains on every 3rd turn.""" + p = provider_with_config(retain_every_n_turns=3) + + p.sync_turn("turn1-user", "turn1-asst") + assert p._sync_thread is None # not retained yet + + p.sync_turn("turn2-user", "turn2-asst") + assert p._sync_thread is None # not retained yet + + p.sync_turn("turn3-user", "turn3-asst") + assert p._sync_thread is not None # retained! + p._sync_thread.join(timeout=5.0) + + p._client.aretain_batch.assert_called_once() + content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"] + # Should contain all 3 turns + assert "turn1-user" in content + assert "turn2-user" in content + assert "turn3-user" in content + + def test_sync_turn_accumulates_full_session(self, provider_with_config): + """Each retain sends the ENTIRE session, not just the latest batch.""" + p = provider_with_config(retain_every_n_turns=2) + + p.sync_turn("turn1-user", "turn1-asst") + p.sync_turn("turn2-user", "turn2-asst") + if p._sync_thread: + p._sync_thread.join(timeout=5.0) + + p._client.aretain_batch.reset_mock() + + p.sync_turn("turn3-user", "turn3-asst") + p.sync_turn("turn4-user", "turn4-asst") + if p._sync_thread: + p._sync_thread.join(timeout=5.0) + + content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"] + # Should contain ALL turns from the session + assert "turn1-user" in content + assert "turn2-user" in content + assert "turn3-user" in content + assert "turn4-user" in content + + def test_sync_turn_passes_document_id(self, provider): + """sync_turn should pass session_id as document_id for dedup.""" + provider.sync_turn("hello", "hi") + if provider._sync_thread: + provider._sync_thread.join(timeout=5.0) + call_kwargs = provider._client.aretain_batch.call_args.kwargs + assert call_kwargs["document_id"] == "test-session" + + def test_sync_turn_error_does_not_raise(self, provider): + """Errors in sync_turn should be swallowed (non-blocking).""" + provider._client.aretain_batch.side_effect = RuntimeError("network error") + provider.sync_turn("hello", "hi") + if provider._sync_thread: + provider._sync_thread.join(timeout=5.0) + # Should not raise + + +# --------------------------------------------------------------------------- +# System prompt tests +# --------------------------------------------------------------------------- + + +class TestSystemPrompt: + def test_hybrid_mode_prompt(self, provider): + block = provider.system_prompt_block() + assert "Hindsight Memory" in block + assert "hindsight_recall" in block + assert "automatically injected" in block + + def test_context_mode_prompt(self, provider_with_config): + p = provider_with_config(memory_mode="context") + block = p.system_prompt_block() + assert "context mode" in block + assert "hindsight_recall" not in block + + def test_tools_mode_prompt(self, provider_with_config): + p = provider_with_config(memory_mode="tools") + block = p.system_prompt_block() + assert "tools mode" in block + assert "hindsight_recall" in block + + +# --------------------------------------------------------------------------- +# Config schema tests +# --------------------------------------------------------------------------- + + +class TestConfigSchema: + def test_schema_has_all_new_fields(self, provider): + schema = provider.get_config_schema() + keys = {f["key"] for f in schema} + expected_keys = { + "mode", "api_url", "api_key", "llm_provider", "llm_api_key", + "llm_model", "bank_id", "bank_mission", "bank_retain_mission", + "recall_budget", "memory_mode", "recall_prefetch_method", + "tags", "recall_tags", "recall_tags_match", + "auto_recall", "auto_retain", + "retain_every_n_turns", "retain_async", + "retain_context", + "recall_max_tokens", "recall_max_input_chars", + "recall_prompt_preamble", + } + assert expected_keys.issubset(keys), f"Missing: {expected_keys - keys}" + + +# --------------------------------------------------------------------------- +# Availability tests +# --------------------------------------------------------------------------- + + +class TestAvailability: + def test_available_with_api_key(self, tmp_path, monkeypatch): + monkeypatch.setattr( + "plugins.memory.hindsight.get_hermes_home", + lambda: tmp_path / "nonexistent", + ) + monkeypatch.setenv("HINDSIGHT_API_KEY", "test-key") + p = HindsightMemoryProvider() + assert p.is_available() + + def test_not_available_without_config(self, tmp_path, monkeypatch): + monkeypatch.setattr( + "plugins.memory.hindsight.get_hermes_home", + lambda: tmp_path / "nonexistent", + ) + p = HindsightMemoryProvider() + assert not p.is_available() + + def test_available_in_local_mode(self, tmp_path, monkeypatch): + monkeypatch.setattr( + "plugins.memory.hindsight.get_hermes_home", + lambda: tmp_path / "nonexistent", + ) + monkeypatch.setenv("HINDSIGHT_MODE", "local") + p = HindsightMemoryProvider() + assert p.is_available() diff --git a/tests/run_agent/test_context_pressure.py b/tests/run_agent/test_context_pressure.py index 522603fdb5..4140749c51 100644 --- a/tests/run_agent/test_context_pressure.py +++ b/tests/run_agent/test_context_pressure.py @@ -150,8 +150,8 @@ def agent(): class TestContextPressureFlags: """Context pressure warning flag tracking on AIAgent.""" - def test_flag_initialized_false(self, agent): - assert agent._context_pressure_warned is False + def test_flag_initialized_zero(self, agent): + assert agent._context_pressure_warned_at == 0.0 def test_emit_calls_status_callback(self, agent): """status_callback should be invoked with event type and message.""" @@ -210,7 +210,7 @@ class TestContextPressureFlags: def test_flag_reset_on_compression(self, agent): """After _compress_context, context pressure flag should reset.""" - agent._context_pressure_warned = True + agent._context_pressure_warned_at = 0.85 agent.compression_enabled = True agent.context_compressor = MagicMock() @@ -219,6 +219,7 @@ class TestContextPressureFlags: ] agent.context_compressor.context_length = 200_000 agent.context_compressor.threshold_tokens = 100_000 + agent.context_compressor.compression_count = 1 agent._todo_store = MagicMock() agent._todo_store.format_for_injection.return_value = None @@ -233,7 +234,7 @@ class TestContextPressureFlags: ] agent._compress_context(messages, "system prompt") - assert agent._context_pressure_warned is False + assert agent._context_pressure_warned_at == 0.0 def test_emit_callback_error_handled(self, agent): """If status_callback raises, it should be caught gracefully.""" @@ -246,3 +247,115 @@ class TestContextPressureFlags: # Should not raise agent._emit_context_pressure(0.85, compressor) + + def test_tiered_reemits_at_95(self, agent): + """Warning fires at 85%, then fires again when crossing 95%.""" + agent._context_pressure_warned_at = 0.85 + # Simulate crossing 95%: the tier (0.95) > warned_at (0.85) + assert 0.95 > agent._context_pressure_warned_at + # After emission at 95%, the tier should update + agent._context_pressure_warned_at = 0.95 + assert agent._context_pressure_warned_at == 0.95 + + def test_tiered_no_double_emit_at_same_level(self, agent): + """Once warned at 85%, further 85%+ readings don't re-warn.""" + agent._context_pressure_warned_at = 0.85 + # At 88%, tier is 0.85, which is NOT > warned_at (0.85) + _warn_tier = 0.85 if 0.88 >= 0.85 else 0.0 + assert not (_warn_tier > agent._context_pressure_warned_at) + + def test_flag_not_reset_when_compression_insufficient(self, agent): + """When compression can't drop below 85%, keep the flag set.""" + agent._context_pressure_warned_at = 0.85 + agent.compression_enabled = True + + agent.context_compressor = MagicMock() + agent.context_compressor.compress.return_value = [ + {"role": "user", "content": "Summary of conversation so far."} + ] + agent.context_compressor.context_length = 200 + # Use a small threshold so the tiny compressed output still + # represents >= 85% of it (prevents flag reset). + agent.context_compressor.threshold_tokens = 10 + agent.context_compressor.compression_count = 1 + agent.context_compressor.last_prompt_tokens = 0 + + agent._todo_store = MagicMock() + agent._todo_store.format_for_injection.return_value = None + agent._build_system_prompt = MagicMock(return_value="system prompt") + agent._cached_system_prompt = "old system prompt" + agent._session_db = None + + messages = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi there"}, + ] + agent._compress_context(messages, "system prompt") + + # Post-compression is ~90% of threshold — flag should NOT reset + assert agent._context_pressure_warned_at == 0.85 + + +class TestContextPressureGatewayDedup: + """Class-level dedup prevents warning spam across AIAgent instances.""" + + def setup_method(self): + """Clear class-level dedup state between tests.""" + AIAgent._context_pressure_last_warned.clear() + + def test_second_instance_within_cooldown_suppressed(self): + """Same session, same tier, within cooldown — should be suppressed.""" + import time + sid = "test_session_dedup" + # Simulate first warning + AIAgent._context_pressure_last_warned[sid] = (0.85, time.time()) + # Second instance checking same tier within cooldown + _last = AIAgent._context_pressure_last_warned.get(sid) + _should_warn = _last is None or _last[0] < 0.85 or (time.time() - _last[1]) >= AIAgent._CONTEXT_PRESSURE_COOLDOWN + assert not _should_warn + + def test_higher_tier_fires_despite_cooldown(self): + """Same session, higher tier — should fire even within cooldown.""" + import time + sid = "test_session_tier" + AIAgent._context_pressure_last_warned[sid] = (0.85, time.time()) + _last = AIAgent._context_pressure_last_warned.get(sid) + # 0.95 > 0.85 stored tier → should warn + _should_warn = _last is None or _last[0] < 0.95 or (time.time() - _last[1]) >= AIAgent._CONTEXT_PRESSURE_COOLDOWN + assert _should_warn + + def test_warning_fires_after_cooldown_expires(self): + """Same session, same tier, after cooldown — should fire again.""" + import time + sid = "test_session_expired" + # Set a timestamp far in the past + AIAgent._context_pressure_last_warned[sid] = (0.85, time.time() - AIAgent._CONTEXT_PRESSURE_COOLDOWN - 1) + _last = AIAgent._context_pressure_last_warned.get(sid) + _should_warn = _last is None or _last[0] < 0.85 or (time.time() - _last[1]) >= AIAgent._CONTEXT_PRESSURE_COOLDOWN + assert _should_warn + + def test_compression_clears_dedup(self): + """After compression drops below 85%, dedup entry should be cleared.""" + import time + sid = "test_session_clear" + AIAgent._context_pressure_last_warned[sid] = (0.85, time.time()) + assert sid in AIAgent._context_pressure_last_warned + # Simulate what _compress_context does on reset + AIAgent._context_pressure_last_warned.pop(sid, None) + assert sid not in AIAgent._context_pressure_last_warned + + def test_eviction_removes_stale_entries(self): + """Stale entries older than 2x cooldown should be evicted.""" + import time + _now = time.time() + AIAgent._context_pressure_last_warned = { + "fresh": (0.85, _now), + "stale": (0.85, _now - AIAgent._CONTEXT_PRESSURE_COOLDOWN * 3), + } + _cutoff = _now - AIAgent._CONTEXT_PRESSURE_COOLDOWN * 2 + AIAgent._context_pressure_last_warned = { + k: v for k, v in AIAgent._context_pressure_last_warned.items() + if v[1] > _cutoff + } + assert "fresh" in AIAgent._context_pressure_last_warned + assert "stale" not in AIAgent._context_pressure_last_warned diff --git a/tests/run_agent/test_flush_memories_codex.py b/tests/run_agent/test_flush_memories_codex.py index 3d12c9d3ea..b4b3c648e6 100644 --- a/tests/run_agent/test_flush_memories_codex.py +++ b/tests/run_agent/test_flush_memories_codex.py @@ -91,6 +91,61 @@ def _chat_response_with_memory_call(): ) +class TestFlushMemoriesRespectsConfigTimeout: + """flush_memories() must NOT hardcode timeout=30.0 — it should defer + to the config value via auxiliary.flush_memories.timeout.""" + + def test_auxiliary_path_omits_explicit_timeout(self, monkeypatch): + """When calling _call_llm, timeout should NOT be passed so that + _get_task_timeout('flush_memories') reads from config.""" + agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter") + + mock_response = _chat_response_with_memory_call() + + with patch("agent.auxiliary_client.call_llm", return_value=mock_response) as mock_call: + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + {"role": "user", "content": "Note this"}, + ] + with patch("tools.memory_tool.memory_tool", return_value="Saved."): + agent.flush_memories(messages) + + mock_call.assert_called_once() + call_kwargs = mock_call.call_args + # timeout must NOT be explicitly passed (so _get_task_timeout resolves it) + assert "timeout" not in call_kwargs.kwargs, ( + "flush_memories should not pass explicit timeout to _call_llm; " + "let _get_task_timeout('flush_memories') resolve from config" + ) + + def test_fallback_path_uses_config_timeout(self, monkeypatch): + """When auxiliary client is unavailable and we fall back to direct + OpenAI client, timeout should come from _get_task_timeout, not hardcoded.""" + agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter") + agent.client = MagicMock() + agent.client.chat.completions.create.return_value = _chat_response_with_memory_call() + + custom_timeout = 180.0 + + with patch("agent.auxiliary_client.call_llm", side_effect=RuntimeError("no provider")), \ + patch("agent.auxiliary_client._get_task_timeout", return_value=custom_timeout) as mock_gtt, \ + patch("tools.memory_tool.memory_tool", return_value="Saved."): + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + {"role": "user", "content": "Save this"}, + ] + agent.flush_memories(messages) + + mock_gtt.assert_called_once_with("flush_memories") + agent.client.chat.completions.create.assert_called_once() + call_kwargs = agent.client.chat.completions.create.call_args + assert call_kwargs.kwargs.get("timeout") == custom_timeout, ( + f"Expected timeout={custom_timeout} from config, got {call_kwargs.kwargs.get('timeout')}" + ) + + class TestFlushMemoriesUsesAuxiliaryClient: """When an auxiliary client is available, flush_memories should use it instead of self.client -- especially critical in Codex mode.""" diff --git a/tests/run_agent/test_run_agent.py b/tests/run_agent/test_run_agent.py index 59f88601c5..98d799ae43 100644 --- a/tests/run_agent/test_run_agent.py +++ b/tests/run_agent/test_run_agent.py @@ -1668,12 +1668,15 @@ class TestRunConversation: if roles[i] == "assistant" and roles[i + 1] == "assistant": raise AssertionError("Consecutive assistant messages found in history") - def test_truly_empty_response_accepted_without_retry(self, agent): - """Truly empty response (no content, no reasoning) should still complete with (empty).""" + def test_truly_empty_response_retries_3_times_then_empty(self, agent): + """Truly empty response (no content, no reasoning) retries 3 times then falls through to (empty).""" self._setup_agent(agent) agent.base_url = "http://127.0.0.1:1234/v1" empty_resp = _mock_response(content=None, finish_reason="stop") - agent.client.chat.completions.create.side_effect = [empty_resp] + # 4 responses: 1 original + 3 nudge retries, all empty + agent.client.chat.completions.create.side_effect = [ + empty_resp, empty_resp, empty_resp, empty_resp, + ] with ( patch.object(agent, "_persist_session"), patch.object(agent, "_save_trajectory"), @@ -1682,7 +1685,28 @@ class TestRunConversation: result = agent.run_conversation("answer me") assert result["completed"] is True assert result["final_response"] == "(empty)" - assert result["api_calls"] == 1 # no retries + assert result["api_calls"] == 4 # 1 original + 3 retries + + def test_truly_empty_response_succeeds_on_nudge(self, agent): + """Model produces content after being nudged for empty response.""" + self._setup_agent(agent) + agent.base_url = "http://127.0.0.1:1234/v1" + empty_resp = _mock_response(content=None, finish_reason="stop") + content_resp = _mock_response( + content="Here is the actual answer.", + finish_reason="stop", + ) + # 1 empty response, then model produces content on nudge + agent.client.chat.completions.create.side_effect = [empty_resp, content_resp] + with ( + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + result = agent.run_conversation("answer me") + assert result["completed"] is True + assert result["final_response"] == "Here is the actual answer." + assert result["api_calls"] == 2 # 1 original + 1 nudge retry def test_nous_401_refreshes_after_remint_and_retries(self, agent): self._setup_agent(agent) diff --git a/tests/skills/test_openclaw_migration.py b/tests/skills/test_openclaw_migration.py index d4aa8f710e..99d126bed5 100644 --- a/tests/skills/test_openclaw_migration.py +++ b/tests/skills/test_openclaw_migration.py @@ -658,6 +658,47 @@ def test_workspace_agents_records_skip_when_missing(tmp_path: Path): assert wa_items[0]["status"] == "skipped" +def test_cron_store_is_archived_without_config_cron_section(tmp_path: Path): + """Bug fix: archive cron store even when openclaw.json has no top-level cron config.""" + mod = load_module() + source = tmp_path / ".openclaw" + target = tmp_path / ".hermes" + output_dir = target / "migration-report" + source.mkdir() + target.mkdir() + + (source / "openclaw.json").write_text(json.dumps({"channels": {}}), encoding="utf-8") + (source / "cron").mkdir(parents=True) + (source / "cron" / "jobs.json").write_text( + json.dumps({"version": 1, "jobs": [{"id": "job-1", "name": "demo"}]}), + encoding="utf-8", + ) + + migrator = mod.Migrator( + source_root=source, + target_root=target, + execute=True, + workspace_target=None, + overwrite=False, + migrate_secrets=False, + output_dir=output_dir, + selected_options={"cron-jobs"}, + ) + report = migrator.migrate() + + cron_items = [item for item in report["items"] if item["kind"] == "cron-jobs"] + archived_store = next( + (item for item in cron_items if item["destination"] and item["destination"].endswith("archive/cron-store")), + None, + ) + assert archived_store is not None + assert Path(archived_store["destination"]).joinpath("jobs.json").exists() + + notes_text = (output_dir / "MIGRATION_NOTES.md").read_text(encoding="utf-8") + assert "Run `hermes cron` to recreate scheduled tasks" in notes_text + assert "archive/cron-config.json" not in notes_text + + def test_skill_installs_cleanly_under_skills_guard(): skills_guard = load_skills_guard() result = skills_guard.scan_skill( diff --git a/tests/test_hermes_state.py b/tests/test_hermes_state.py index a0630858c8..5f9a16a529 100644 --- a/tests/test_hermes_state.py +++ b/tests/test_hermes_state.py @@ -663,6 +663,84 @@ class TestPruneSessions: assert db.get_session("old_cli") is None assert db.get_session("old_tg") is not None + def test_prune_with_multilevel_chain(self, db): + """Pruning old sessions orphans newer children instead of crashing on FK.""" + old_ts = time.time() - 200 * 86400 + recent_ts = time.time() - 10 * 86400 + + # Chain: A (old) -> B (old) -> C (recent) -> D (recent) + db.create_session(session_id="A", source="cli") + db.end_session("A", end_reason="compressed") + db.create_session(session_id="B", source="cli", parent_session_id="A") + db.end_session("B", end_reason="compressed") + db.create_session(session_id="C", source="cli", parent_session_id="B") + db.end_session("C", end_reason="compressed") + db.create_session(session_id="D", source="cli", parent_session_id="C") + db.end_session("D", end_reason="done") + + # Backdate A and B to be old; C and D stay recent + for sid, ts in [("A", old_ts), ("B", old_ts), ("C", recent_ts), ("D", recent_ts)]: + db._conn.execute( + "UPDATE sessions SET started_at = ? WHERE id = ?", (ts, sid) + ) + db._conn.commit() + + # Should not raise IntegrityError + pruned = db.prune_sessions(older_than_days=90) + assert pruned == 2 # only A and B + assert db.get_session("A") is None + assert db.get_session("B") is None + # C and D survive, C is orphaned (parent_session_id NULL) + c = db.get_session("C") + assert c is not None + assert c["parent_session_id"] is None + d = db.get_session("D") + assert d is not None + assert d["parent_session_id"] == "C" + + def test_prune_entire_old_chain(self, db): + """All sessions in a chain are old — entire chain is pruned.""" + old_ts = time.time() - 200 * 86400 + + db.create_session(session_id="X", source="cli") + db.end_session("X", end_reason="compressed") + db.create_session(session_id="Y", source="cli", parent_session_id="X") + db.end_session("Y", end_reason="compressed") + db.create_session(session_id="Z", source="cli", parent_session_id="Y") + db.end_session("Z", end_reason="done") + + for sid in ("X", "Y", "Z"): + db._conn.execute( + "UPDATE sessions SET started_at = ? WHERE id = ?", (old_ts, sid) + ) + db._conn.commit() + + pruned = db.prune_sessions(older_than_days=90) + assert pruned == 3 + for sid in ("X", "Y", "Z"): + assert db.get_session(sid) is None + + +class TestDeleteSessionOrphansChildren: + def test_delete_orphans_children(self, db): + """Deleting a parent session orphans its children.""" + db.create_session(session_id="parent", source="cli") + db.create_session(session_id="child", source="cli", parent_session_id="parent") + db.create_session(session_id="grandchild", source="cli", parent_session_id="child") + + # Should not raise IntegrityError + result = db.delete_session("parent") + assert result is True + assert db.get_session("parent") is None + # Child is orphaned, not deleted + child = db.get_session("child") + assert child is not None + assert child["parent_session_id"] is None + # Grandchild is untouched + grandchild = db.get_session("grandchild") + assert grandchild is not None + assert grandchild["parent_session_id"] == "child" + # ========================================================================= # Schema and WAL mode diff --git a/tests/tools/test_base_environment.py b/tests/tools/test_base_environment.py new file mode 100644 index 0000000000..913ad0387c --- /dev/null +++ b/tests/tools/test_base_environment.py @@ -0,0 +1,174 @@ +"""Tests for BaseEnvironment unified execution model. + +Tests _wrap_command(), _extract_cwd_from_output(), _embed_stdin_heredoc(), +init_session() failure handling, and the CWD marker contract. +""" + +import uuid +from unittest.mock import MagicMock + +from tools.environments.base import BaseEnvironment, _cwd_marker + + +class _TestableEnv(BaseEnvironment): + """Concrete subclass for testing base class methods.""" + + def __init__(self, cwd="/tmp", timeout=10): + super().__init__(cwd=cwd, timeout=timeout) + + def _run_bash(self, cmd_string, *, login=False, timeout=120, stdin_data=None): + raise NotImplementedError("Use mock") + + def cleanup(self): + pass + + +class TestWrapCommand: + def test_basic_shape(self): + env = _TestableEnv() + env._snapshot_ready = True + wrapped = env._wrap_command("echo hello", "/tmp") + + assert "source" in wrapped + assert "cd /tmp" in wrapped or "cd '/tmp'" in wrapped + assert "eval 'echo hello'" in wrapped + assert "__hermes_ec=$?" in wrapped + assert "export -p >" in wrapped + assert "pwd -P >" in wrapped + assert env._cwd_marker in wrapped + assert "exit $__hermes_ec" in wrapped + + def test_no_snapshot_skips_source(self): + env = _TestableEnv() + env._snapshot_ready = False + wrapped = env._wrap_command("echo hello", "/tmp") + + assert "source" not in wrapped + + def test_single_quote_escaping(self): + env = _TestableEnv() + env._snapshot_ready = True + wrapped = env._wrap_command("echo 'hello world'", "/tmp") + + assert "eval 'echo '\\''hello world'\\'''" in wrapped + + def test_tilde_not_quoted(self): + env = _TestableEnv() + env._snapshot_ready = True + wrapped = env._wrap_command("ls", "~") + + assert "cd ~" in wrapped + assert "cd '~'" not in wrapped + + def test_cd_failure_exit_126(self): + env = _TestableEnv() + env._snapshot_ready = True + wrapped = env._wrap_command("ls", "/nonexistent") + + assert "exit 126" in wrapped + + +class TestExtractCwdFromOutput: + def test_happy_path(self): + env = _TestableEnv() + marker = env._cwd_marker + result = { + "output": f"hello\n{marker}/home/user{marker}\n", + } + env._extract_cwd_from_output(result) + + assert env.cwd == "/home/user" + assert marker not in result["output"] + + def test_missing_marker(self): + env = _TestableEnv() + result = {"output": "hello world\n"} + env._extract_cwd_from_output(result) + + assert env.cwd == "/tmp" # unchanged + + def test_marker_in_command_output(self): + """If the marker appears in command output AND as the real marker, + rfind grabs the last (real) one.""" + env = _TestableEnv() + marker = env._cwd_marker + result = { + "output": f"user typed {marker} in their output\nreal output\n{marker}/correct/path{marker}\n", + } + env._extract_cwd_from_output(result) + + assert env.cwd == "/correct/path" + + def test_output_cleaned(self): + env = _TestableEnv() + marker = env._cwd_marker + result = { + "output": f"hello\n{marker}/tmp{marker}\n", + } + env._extract_cwd_from_output(result) + + assert "hello" in result["output"] + assert marker not in result["output"] + + +class TestEmbedStdinHeredoc: + def test_heredoc_format(self): + result = BaseEnvironment._embed_stdin_heredoc("cat", "hello world") + + assert result.startswith("cat << '") + assert "hello world" in result + assert "HERMES_STDIN_" in result + + def test_unique_delimiter_each_call(self): + r1 = BaseEnvironment._embed_stdin_heredoc("cat", "data") + r2 = BaseEnvironment._embed_stdin_heredoc("cat", "data") + + # Extract delimiters + d1 = r1.split("'")[1] + d2 = r2.split("'")[1] + assert d1 != d2 # UUID-based, should be unique + + +class TestInitSessionFailure: + def test_snapshot_ready_false_on_failure(self): + env = _TestableEnv() + + def failing_run_bash(*args, **kwargs): + raise RuntimeError("bash not found") + + env._run_bash = failing_run_bash + env.init_session() + + assert env._snapshot_ready is False + + def test_login_flag_when_snapshot_not_ready(self): + """When _snapshot_ready=False, execute() should pass login=True to _run_bash.""" + env = _TestableEnv() + env._snapshot_ready = False + + calls = [] + def mock_run_bash(cmd, *, login=False, timeout=120, stdin_data=None): + calls.append({"login": login}) + # Return a mock process handle + mock = MagicMock() + mock.poll.return_value = 0 + mock.returncode = 0 + mock.stdout = iter([]) + return mock + + env._run_bash = mock_run_bash + env.execute("echo test") + + assert len(calls) == 1 + assert calls[0]["login"] is True + + +class TestCwdMarker: + def test_marker_contains_session_id(self): + env = _TestableEnv() + assert env._session_id in env._cwd_marker + + def test_unique_per_instance(self): + env1 = _TestableEnv() + env2 = _TestableEnv() + assert env1._cwd_marker != env2._cwd_marker diff --git a/tests/tools/test_daytona_environment.py b/tests/tools/test_daytona_environment.py index 04e6347955..7f5aa17ece 100644 --- a/tests/tools/test_daytona_environment.py +++ b/tests/tools/test_daytona_environment.py @@ -59,8 +59,8 @@ def daytona_sdk(monkeypatch): @pytest.fixture() def make_env(daytona_sdk, monkeypatch): """Factory that creates a DaytonaEnvironment with a mocked SDK.""" - # Prevent is_interrupted from interfering - monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False) + # Prevent is_interrupted from interfering — patch where it's used (base.py) + monkeypatch.setattr("tools.environments.base.is_interrupted", lambda: False) # Prevent skills/credential sync from consuming mock exec calls monkeypatch.setattr("tools.credential_files.get_credential_file_mounts", lambda: []) monkeypatch.setattr("tools.credential_files.get_skills_directory_mount", lambda **kw: None) @@ -221,41 +221,45 @@ class TestCleanup: class TestExecute: def test_basic_command(self, make_env): sb = _make_sandbox() - # First call: $HOME detection; subsequent calls: actual commands + # Calls: (1) $HOME detection, (2) init_session bootstrap, (3) actual command sb.process.exec.side_effect = [ _make_exec_response(result="/root"), # $HOME + _make_exec_response(result="", exit_code=0), # init_session _make_exec_response(result="hello", exit_code=0), # actual cmd ] sb.state = "started" env = make_env(sandbox=sb) result = env.execute("echo hello") - assert result["output"] == "hello" + assert "hello" in result["output"] assert result["returncode"] == 0 - def test_command_wrapped_with_shell_timeout(self, make_env): + def test_sdk_timeout_passed_to_exec(self, make_env): + """SDK native timeout is passed to sandbox.process.exec().""" sb = _make_sandbox() sb.process.exec.side_effect = [ _make_exec_response(result="/root"), + _make_exec_response(result="", exit_code=0), # init_session _make_exec_response(result="ok", exit_code=0), ] sb.state = "started" env = make_env(sandbox=sb, timeout=42) env.execute("echo hello") - # The command sent to exec should be wrapped with `timeout N sh -c '...'` + # The exec call should receive timeout= kwarg (SDK native timeout) call_args = sb.process.exec.call_args_list[-1] + assert call_args[1]["timeout"] == 42 + # The command should NOT have a shell `timeout` prefix cmd = call_args[0][0] - assert cmd.startswith("timeout 42 sh -c ") - # SDK timeout param should NOT be passed - assert "timeout" not in call_args[1] + assert not cmd.startswith("timeout ") def test_timeout_returns_exit_code_124(self, make_env): - """Shell timeout utility returns exit code 124.""" + """SDK-level timeout surfaces as exit code 124 via _wait_for_process.""" sb = _make_sandbox() sb.process.exec.side_effect = [ _make_exec_response(result="/root"), - _make_exec_response(result="", exit_code=124), + _make_exec_response(result="", exit_code=0), # init_session + _make_exec_response(result="", exit_code=124), # actual cmd ] sb.state = "started" env = make_env(sandbox=sb) @@ -267,6 +271,7 @@ class TestExecute: sb = _make_sandbox() sb.process.exec.side_effect = [ _make_exec_response(result="/root"), + _make_exec_response(result="", exit_code=0), # init_session _make_exec_response(result="not found", exit_code=127), ] sb.state = "started" @@ -279,6 +284,7 @@ class TestExecute: sb = _make_sandbox() sb.process.exec.side_effect = [ _make_exec_response(result="/root"), + _make_exec_response(result="", exit_code=0), # init_session _make_exec_response(result="ok", exit_code=0), ] sb.state = "started" @@ -286,39 +292,47 @@ class TestExecute: env.execute("python3", stdin_data="print('hi')") # Check that the command passed to exec contains heredoc markers - # (single quotes get shell-escaped by shlex.quote, so check components) + # Base class uses HERMES_STDIN_ prefix for heredoc delimiters call_args = sb.process.exec.call_args_list[-1] cmd = call_args[0][0] - assert "HERMES_EOF_" in cmd + assert "HERMES_STDIN_" in cmd assert "print" in cmd assert "hi" in cmd - def test_custom_cwd_passed_through(self, make_env): + def test_custom_cwd_in_command_wrapper(self, make_env): + """CWD is handled by _wrap_command() in the command string, not as a kwarg.""" sb = _make_sandbox() sb.process.exec.side_effect = [ _make_exec_response(result="/root"), + _make_exec_response(result="", exit_code=0), # init_session _make_exec_response(result="/tmp", exit_code=0), ] sb.state = "started" env = make_env(sandbox=sb) env.execute("pwd", cwd="/tmp") - call_kwargs = sb.process.exec.call_args_list[-1][1] - assert call_kwargs["cwd"] == "/tmp" + # CWD should be embedded in the command string via _wrap_command + call_args = sb.process.exec.call_args_list[-1] + cmd = call_args[0][0] + assert "cd /tmp" in cmd + # CWD should NOT be passed as a kwarg to exec + assert "cwd" not in call_args[1] def test_daytona_error_triggers_retry(self, make_env, daytona_sdk): sb = _make_sandbox() sb.state = "started" sb.process.exec.side_effect = [ _make_exec_response(result="/root"), # $HOME + _make_exec_response(result="", exit_code=0), # init_session daytona_sdk.DaytonaError("transient"), # first attempt fails _make_exec_response(result="ok", exit_code=0), # retry succeeds ] env = make_env(sandbox=sb) result = env.execute("echo retry") - assert result["output"] == "ok" - assert result["returncode"] == 0 + # DaytonaError now surfaces directly through _ThreadedProcessHandle + # (no retry logic) — the error becomes returncode=1 + assert result["returncode"] == 1 # --------------------------------------------------------------------------- @@ -359,14 +373,18 @@ class TestInterrupt: calls["n"] += 1 if calls["n"] == 1: return _make_exec_response(result="/root") # $HOME detection + if calls["n"] == 2: + return _make_exec_response(result="", exit_code=0) # init_session event.wait(timeout=5) # simulate long-running command return _make_exec_response(result="done", exit_code=0) sb.process.exec.side_effect = exec_side_effect env = make_env(sandbox=sb) + # is_interrupted is checked by base.py's _wait_for_process, + # patch where it's actually referenced (base.py's local binding) monkeypatch.setattr( - "tools.environments.daytona.is_interrupted", lambda: True + "tools.environments.base.is_interrupted", lambda: True ) try: result = env.execute("sleep 10") @@ -377,23 +395,24 @@ class TestInterrupt: # --------------------------------------------------------------------------- -# Retry exhaustion +# DaytonaError surfaces directly (no retry) # --------------------------------------------------------------------------- class TestRetryExhausted: def test_both_attempts_fail(self, make_env, daytona_sdk): + """DaytonaError surfaces directly as rc=1 (retry logic was removed).""" sb = _make_sandbox() sb.state = "started" sb.process.exec.side_effect = [ _make_exec_response(result="/root"), # $HOME - daytona_sdk.DaytonaError("fail1"), # first attempt - daytona_sdk.DaytonaError("fail2"), # retry + _make_exec_response(result="", exit_code=0), # init_session + daytona_sdk.DaytonaError("fail1"), # actual command fails ] env = make_env(sandbox=sb) result = env.execute("echo x") + # Error surfaces directly through _ThreadedProcessHandle (rc=1) assert result["returncode"] == 1 - assert "Daytona execution error" in result["output"] # --------------------------------------------------------------------------- diff --git a/tests/tools/test_docker_environment.py b/tests/tools/test_docker_environment.py index ce98217cf8..498ef9d506 100644 --- a/tests/tools/test_docker_environment.py +++ b/tests/tools/test_docker_environment.py @@ -245,43 +245,42 @@ def _make_execute_only_env(forward_env=None): env._timeout_result = lambda timeout: {"output": f"timed out after {timeout}", "returncode": 124} env._container_id = "test-container" env._docker_exe = "/usr/bin/docker" + # Base class attributes needed by unified execute() + env._session_id = "test123" + env._snapshot_path = "/tmp/hermes-snap-test123.sh" + env._cwd_file = "/tmp/hermes-cwd-test123.txt" + env._cwd_marker = "__HERMES_CWD_test123__" + env._snapshot_ready = True + env._last_sync_time = None + env._init_env_args = [] return env -def test_execute_uses_hermes_dotenv_for_allowlisted_env(monkeypatch): +def test_init_env_args_uses_hermes_dotenv_for_allowlisted_env(monkeypatch): + """_build_init_env_args picks up forwarded env vars from .env file at init time.""" env = _make_execute_only_env(["GITHUB_TOKEN"]) - popen_calls = [] - - def _fake_popen(cmd, **kwargs): - popen_calls.append(cmd) - return _FakePopen(cmd, **kwargs) monkeypatch.delenv("GITHUB_TOKEN", raising=False) monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {"GITHUB_TOKEN": "value_from_dotenv"}) - monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen) - result = env.execute("echo hi") + args = env._build_init_env_args() + args_str = " ".join(args) - assert result["returncode"] == 0 - assert "GITHUB_TOKEN=value_from_dotenv" in popen_calls[0] + assert "GITHUB_TOKEN=value_from_dotenv" in args_str -def test_execute_prefers_shell_env_over_hermes_dotenv(monkeypatch): +def test_init_env_args_prefers_shell_env_over_hermes_dotenv(monkeypatch): + """Shell env vars take priority over .env file values in init env args.""" env = _make_execute_only_env(["GITHUB_TOKEN"]) - popen_calls = [] - - def _fake_popen(cmd, **kwargs): - popen_calls.append(cmd) - return _FakePopen(cmd, **kwargs) monkeypatch.setenv("GITHUB_TOKEN", "value_from_shell") monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {"GITHUB_TOKEN": "value_from_dotenv"}) - monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen) - env.execute("echo hi") + args = env._build_init_env_args() + args_str = " ".join(args) - assert "GITHUB_TOKEN=value_from_shell" in popen_calls[0] - assert "GITHUB_TOKEN=value_from_dotenv" not in popen_calls[0] + assert "GITHUB_TOKEN=value_from_shell" in args_str + assert "value_from_dotenv" not in args_str # ── docker_env tests ────────────────────────────────────────────── @@ -302,64 +301,46 @@ def test_docker_env_appears_in_run_command(monkeypatch): assert "GNUPGHOME=/root/.gnupg" in run_args_str -def test_docker_env_appears_in_exec_command(monkeypatch): - """Explicit docker_env values should also be passed via -e at docker exec time.""" +def test_docker_env_appears_in_init_env_args(monkeypatch): + """Explicit docker_env values should appear in _build_init_env_args.""" env = _make_execute_only_env() env._env = {"MY_VAR": "my_value"} - popen_calls = [] - def _fake_popen(cmd, **kwargs): - popen_calls.append(cmd) - return _FakePopen(cmd, **kwargs) + args = env._build_init_env_args() + args_str = " ".join(args) - monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen) - - env.execute("echo hi") - - assert popen_calls, "Popen should have been called" - assert "MY_VAR=my_value" in popen_calls[0] + assert "MY_VAR=my_value" in args_str -def test_forward_env_overrides_docker_env(monkeypatch): +def test_forward_env_overrides_docker_env_in_init_args(monkeypatch): """docker_forward_env should override docker_env for the same key.""" env = _make_execute_only_env(forward_env=["MY_KEY"]) env._env = {"MY_KEY": "static_value"} - popen_calls = [] - - def _fake_popen(cmd, **kwargs): - popen_calls.append(cmd) - return _FakePopen(cmd, **kwargs) monkeypatch.setenv("MY_KEY", "dynamic_value") monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {}) - monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen) - env.execute("echo hi") + args = env._build_init_env_args() + args_str = " ".join(args) - cmd_str = " ".join(popen_calls[0]) - assert "MY_KEY=dynamic_value" in cmd_str - assert "MY_KEY=static_value" not in cmd_str + assert "MY_KEY=dynamic_value" in args_str + assert "MY_KEY=static_value" not in args_str -def test_docker_env_and_forward_env_merge(monkeypatch): +def test_docker_env_and_forward_env_merge_in_init_args(monkeypatch): """docker_env and docker_forward_env with different keys should both appear.""" env = _make_execute_only_env(forward_env=["TOKEN"]) env._env = {"SSH_AUTH_SOCK": "/run/user/1000/agent.sock"} - popen_calls = [] - - def _fake_popen(cmd, **kwargs): - popen_calls.append(cmd) - return _FakePopen(cmd, **kwargs) monkeypatch.setenv("TOKEN", "secret123") monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {}) - monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen) - env.execute("echo hi") + args = env._build_init_env_args() + args_str = " ".join(args) + + assert "SSH_AUTH_SOCK=/run/user/1000/agent.sock" in args_str + assert "TOKEN=secret123" in args_str - cmd_str = " ".join(popen_calls[0]) - assert "SSH_AUTH_SOCK=/run/user/1000/agent.sock" in cmd_str - assert "TOKEN=secret123" in cmd_str def test_normalize_env_dict_filters_invalid_keys(): diff --git a/tests/tools/test_file_tools_live.py b/tests/tools/test_file_tools_live.py index 4daf19a030..6c3500eb88 100644 --- a/tests/tools/test_file_tools_live.py +++ b/tests/tools/test_file_tools_live.py @@ -22,21 +22,19 @@ import pytest sys.path.insert(0, str(Path(__file__).resolve().parents[2])) -from tools.environments.local import ( - LocalEnvironment, - _clean_shell_noise, - _extract_fenced_output, - _OUTPUT_FENCE, - _SHELL_NOISE_SUBSTRINGS, -) +from tools.environments.local import LocalEnvironment from tools.file_operations import ShellFileOperations # ── Shared noise detection ─────────────────────────────────────────────── -# Every known shell noise pattern. If ANY of these appear in output that -# isn't explicitly expected, the test fails with a clear message. +# Known shell noise patterns that should never appear in command output. -_ALL_NOISE_PATTERNS = list(_SHELL_NOISE_SUBSTRINGS) + [ +_ALL_NOISE_PATTERNS = [ + "bash: cannot set terminal process group", + "bash: no job control in this shell", + "no job control in this shell", + "cannot set terminal process group", + "tcsetattr: Inappropriate ioctl for device", "bash: ", "Inappropriate ioctl", "Auto-suggestions:", @@ -88,134 +86,6 @@ def populated_dir(tmp_path): return tmp_path -# ── _clean_shell_noise unit tests ──────────────────────────────────────── - -class TestCleanShellNoise: - def test_single_noise_line(self): - output = "bash: no job control in this shell\nhello world\n" - result = _clean_shell_noise(output) - assert result == "hello world\n" - - def test_double_noise_lines(self): - output = ( - "bash: cannot set terminal process group (-1): Inappropriate ioctl for device\n" - "bash: no job control in this shell\n" - "actual output here\n" - ) - result = _clean_shell_noise(output) - assert result == "actual output here\n" - _assert_clean(result) - - def test_tcsetattr_noise(self): - output = ( - "bash: [12345: 2 (255)] tcsetattr: Inappropriate ioctl for device\n" - "real content\n" - ) - result = _clean_shell_noise(output) - assert result == "real content\n" - _assert_clean(result) - - def test_triple_noise_lines(self): - output = ( - "bash: cannot set terminal process group (-1): Inappropriate ioctl for device\n" - "bash: no job control in this shell\n" - "bash: [999: 2 (255)] tcsetattr: Inappropriate ioctl for device\n" - "clean\n" - ) - result = _clean_shell_noise(output) - assert result == "clean\n" - - def test_no_noise_untouched(self): - assert _clean_shell_noise("hello\nworld\n") == "hello\nworld\n" - - def test_empty_string(self): - assert _clean_shell_noise("") == "" - - def test_only_noise_produces_empty(self): - output = "bash: no job control in this shell\n" - result = _clean_shell_noise(output) - _assert_clean(result) - - def test_noise_in_middle_not_stripped(self): - """Noise in the middle is real output and should be preserved.""" - output = "real\nbash: no job control in this shell\nmore real\n" - result = _clean_shell_noise(output) - assert result == output - - def test_zsh_restored_session(self): - output = "Restored session: Mon Mar 2 22:16:54 +03 2026\nhello\n" - result = _clean_shell_noise(output) - assert result == "hello\n" - - def test_zsh_saving_session_trailing(self): - output = "hello\nSaving session...completed.\n" - result = _clean_shell_noise(output) - assert result == "hello\n" - - def test_zsh_oh_my_zsh_banner(self): - output = "Oh My Zsh on! | Auto-suggestions: press right\nhello\n" - result = _clean_shell_noise(output) - assert result == "hello\n" - - def test_zsh_full_noise_sandwich(self): - """Both leading and trailing zsh noise stripped.""" - output = ( - "Restored session: Mon Mar 2\n" - "command not found: docker\n" - "Oh My Zsh on!\n" - "actual output\n" - "Saving session...completed.\n" - ) - result = _clean_shell_noise(output) - assert result == "actual output\n" - - def test_last_login_stripped(self): - output = "Last login: Mon Mar 2 22:00:00 on ttys001\nhello\n" - result = _clean_shell_noise(output) - assert result == "hello\n" - - -# ── _extract_fenced_output unit tests ──────────────────────────────────── - -class TestExtractFencedOutput: - def test_normal_fenced_output(self): - raw = f"noise\n{_OUTPUT_FENCE}hello world\n{_OUTPUT_FENCE}more noise\n" - assert _extract_fenced_output(raw) == "hello world\n" - - def test_no_trailing_newline(self): - """printf output with no trailing newline is preserved.""" - raw = f"noise{_OUTPUT_FENCE}exact{_OUTPUT_FENCE}noise" - assert _extract_fenced_output(raw) == "exact" - - def test_no_fences_falls_back(self): - """Without fences, falls back to pattern-based cleaning.""" - raw = "bash: no job control in this shell\nhello\n" - result = _extract_fenced_output(raw) - assert result == "hello\n" - - def test_only_start_fence(self): - """Only start fence (e.g. user command called exit).""" - raw = f"noise{_OUTPUT_FENCE}hello\nSaving session...\n" - result = _extract_fenced_output(raw) - assert result == "hello\n" - - def test_user_outputs_fence_string(self): - """If user command outputs the fence marker, it is preserved.""" - raw = f"noise{_OUTPUT_FENCE}{_OUTPUT_FENCE}real\n{_OUTPUT_FENCE}noise" - result = _extract_fenced_output(raw) - # first fence -> last fence captures the middle including user's fence - assert _OUTPUT_FENCE in result - assert "real\n" in result - - def test_empty_command_output(self): - raw = f"noise{_OUTPUT_FENCE}{_OUTPUT_FENCE}noise" - assert _extract_fenced_output(raw) == "" - - def test_multiline_output(self): - raw = f"noise\n{_OUTPUT_FENCE}line1\nline2\nline3\n{_OUTPUT_FENCE}noise\n" - assert _extract_fenced_output(raw) == "line1\nline2\nline3\n" - - # ── LocalEnvironment.execute() ─────────────────────────────────────────── class TestLocalEnvironmentExecute: diff --git a/tests/tools/test_local_persistent.py b/tests/tools/test_local_persistent.py deleted file mode 100644 index 5b9ce2e238..0000000000 --- a/tests/tools/test_local_persistent.py +++ /dev/null @@ -1,164 +0,0 @@ -"""Tests for the local persistent shell backend.""" - -import glob as glob_mod - -import pytest - -from tools.environments.local import LocalEnvironment -from tools.environments.persistent_shell import PersistentShellMixin - - -class TestLocalConfig: - def test_local_persistent_default_false(self, monkeypatch): - monkeypatch.delenv("TERMINAL_LOCAL_PERSISTENT", raising=False) - from tools.terminal_tool import _get_env_config - assert _get_env_config()["local_persistent"] is False - - def test_local_persistent_true(self, monkeypatch): - monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "true") - from tools.terminal_tool import _get_env_config - assert _get_env_config()["local_persistent"] is True - - def test_local_persistent_yes(self, monkeypatch): - monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "yes") - from tools.terminal_tool import _get_env_config - assert _get_env_config()["local_persistent"] is True - - -class TestMergeOutput: - def test_stdout_only(self): - assert PersistentShellMixin._merge_output("out", "") == "out" - - def test_stderr_only(self): - assert PersistentShellMixin._merge_output("", "err") == "err" - - def test_both(self): - assert PersistentShellMixin._merge_output("out", "err") == "out\nerr" - - def test_empty(self): - assert PersistentShellMixin._merge_output("", "") == "" - - def test_strips_trailing_newlines(self): - assert PersistentShellMixin._merge_output("out\n\n", "err\n") == "out\nerr" - - -class TestLocalOneShotRegression: - def test_echo(self): - env = LocalEnvironment(persistent=False) - r = env.execute("echo hello") - assert r["returncode"] == 0 - assert "hello" in r["output"] - env.cleanup() - - def test_exit_code(self): - env = LocalEnvironment(persistent=False) - r = env.execute("exit 42") - assert r["returncode"] == 42 - env.cleanup() - - def test_state_does_not_persist(self): - env = LocalEnvironment(persistent=False) - env.execute("export HERMES_ONESHOT_LOCAL=yes") - r = env.execute("echo $HERMES_ONESHOT_LOCAL") - assert r["output"].strip() == "" - env.cleanup() - - def test_oneshot_heredoc_does_not_leak_fence_wrapper(self): - """Heredoc closing line must not be merged with the fence wrapper tail.""" - env = LocalEnvironment(persistent=False) - cmd = "cat <<'H_EOF'\nheredoc body line\nH_EOF" - r = env.execute(cmd) - env.cleanup() - assert r["returncode"] == 0 - assert "heredoc body line" in r["output"] - assert "__hermes_rc" not in r["output"] - assert "printf '" not in r["output"] - assert "exit $" not in r["output"] - - -class TestLocalPersistent: - @pytest.fixture - def env(self): - e = LocalEnvironment(persistent=True) - yield e - e.cleanup() - - def test_echo(self, env): - r = env.execute("echo hello-persistent") - assert r["returncode"] == 0 - assert "hello-persistent" in r["output"] - - def test_env_var_persists(self, env): - env.execute("export HERMES_LOCAL_PERSIST_TEST=works") - r = env.execute("echo $HERMES_LOCAL_PERSIST_TEST") - assert r["output"].strip() == "works" - - def test_cwd_persists(self, env): - env.execute("cd /tmp") - r = env.execute("pwd") - assert r["output"].strip() == "/tmp" - - def test_exit_code(self, env): - r = env.execute("(exit 42)") - assert r["returncode"] == 42 - - def test_stderr(self, env): - r = env.execute("echo oops >&2") - assert r["returncode"] == 0 - assert "oops" in r["output"] - - def test_multiline_output(self, env): - r = env.execute("echo a; echo b; echo c") - lines = r["output"].strip().splitlines() - assert lines == ["a", "b", "c"] - - def test_timeout_then_recovery(self, env): - r = env.execute("sleep 999", timeout=2) - assert r["returncode"] in (124, 130) - r = env.execute("echo alive") - assert r["returncode"] == 0 - assert "alive" in r["output"] - - def test_large_output(self, env): - r = env.execute("seq 1 1000") - assert r["returncode"] == 0 - lines = r["output"].strip().splitlines() - assert len(lines) == 1000 - assert lines[0] == "1" - assert lines[-1] == "1000" - - def test_shell_variable_persists(self, env): - env.execute("MY_LOCAL_VAR=hello123") - r = env.execute("echo $MY_LOCAL_VAR") - assert r["output"].strip() == "hello123" - - def test_cleanup_removes_temp_files(self, env): - env.execute("echo warmup") - prefix = env._temp_prefix - assert len(glob_mod.glob(f"{prefix}-*")) > 0 - env.cleanup() - remaining = glob_mod.glob(f"{prefix}-*") - assert remaining == [] - - def test_state_does_not_leak_between_instances(self): - env1 = LocalEnvironment(persistent=True) - env2 = LocalEnvironment(persistent=True) - try: - env1.execute("export LEAK_TEST=from_env1") - r = env2.execute("echo $LEAK_TEST") - assert r["output"].strip() == "" - finally: - env1.cleanup() - env2.cleanup() - - def test_special_characters_in_command(self, env): - r = env.execute("echo 'hello world'") - assert r["output"].strip() == "hello world" - - def test_pipe_command(self, env): - r = env.execute("echo hello | tr 'h' 'H'") - assert r["output"].strip() == "Hello" - - def test_multiple_commands_semicolon(self, env): - r = env.execute("X=42; echo $X") - assert r["output"].strip() == "42" diff --git a/tests/tools/test_managed_modal_environment.py b/tests/tools/test_managed_modal_environment.py index ded9cd3d4b..1d7241e0b7 100644 --- a/tests/tools/test_managed_modal_environment.py +++ b/tests/tools/test_managed_modal_environment.py @@ -110,7 +110,7 @@ class _FakeResponse: def test_managed_modal_execute_polls_until_completed(monkeypatch): _install_fake_tools_package() managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py") - modal_common = sys.modules["tools.environments.modal_common"] + modal_common = sys.modules["tools.environments.modal_utils"] calls = [] poll_count = {"value": 0} @@ -173,7 +173,7 @@ def test_managed_modal_create_sends_a_stable_idempotency_key(monkeypatch): def test_managed_modal_execute_cancels_on_interrupt(monkeypatch): interrupt_event = _install_fake_tools_package() managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py") - modal_common = sys.modules["tools.environments.modal_common"] + modal_common = sys.modules["tools.environments.modal_utils"] calls = [] @@ -215,7 +215,7 @@ def test_managed_modal_execute_cancels_on_interrupt(monkeypatch): def test_managed_modal_execute_returns_descriptive_error_on_missing_exec(monkeypatch): _install_fake_tools_package() managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py") - modal_common = sys.modules["tools.environments.modal_common"] + modal_common = sys.modules["tools.environments.modal_utils"] def fake_request(method, url, headers=None, json=None, timeout=None): if method == "POST" and url.endswith("/v1/sandboxes"): @@ -293,7 +293,7 @@ def test_managed_modal_rejects_host_credential_passthrough(): def test_managed_modal_execute_times_out_and_cancels(monkeypatch): _install_fake_tools_package() managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py") - modal_common = sys.modules["tools.environments.modal_common"] + modal_common = sys.modules["tools.environments.modal_utils"] calls = [] monotonic_values = iter([0.0, 12.5]) diff --git a/tests/tools/test_modal_sandbox_fixes.py b/tests/tools/test_modal_sandbox_fixes.py index e1baf13d98..570ef5b218 100644 --- a/tests/tools/test_modal_sandbox_fixes.py +++ b/tests/tools/test_modal_sandbox_fixes.py @@ -231,20 +231,20 @@ class TestEnsurepipFix: """Verify the pip fix is applied in the ModalEnvironment init.""" def test_modal_environment_creates_image_with_setup_commands(self): - """ModalEnvironment.__init__ should create a modal.Image with pip fix.""" + """_resolve_modal_image should create a modal.Image with pip fix.""" try: - from tools.environments.modal import ModalEnvironment + from tools.environments.modal import _resolve_modal_image except ImportError: pytest.skip("tools.environments.modal not importable") import inspect - source = inspect.getsource(ModalEnvironment.__init__) + source = inspect.getsource(_resolve_modal_image) assert "ensurepip" in source, ( - "ModalEnvironment should include ensurepip fix " + "_resolve_modal_image should include ensurepip fix " "for Modal's legacy image builder" ) assert "setup_dockerfile_commands" in source, ( - "ModalEnvironment should use setup_dockerfile_commands " + "_resolve_modal_image should use setup_dockerfile_commands " "to fix pip before Modal's bootstrap" ) diff --git a/tests/tools/test_modal_snapshot_isolation.py b/tests/tools/test_modal_snapshot_isolation.py index a3d0eeacd7..b58454cc07 100644 --- a/tests/tools/test_modal_snapshot_isolation.py +++ b/tests/tools/test_modal_snapshot_isolation.py @@ -85,11 +85,47 @@ def _install_modal_test_modules( def _prepare_command(self, command: str): return command, None - sys.modules["tools.environments.base"] = types.SimpleNamespace(BaseEnvironment=_DummyBaseEnvironment) + def init_session(self): + pass + + # Stub _ThreadedProcessHandle: modal.py imports it but only uses it at + # runtime inside _run_bash; the snapshot-isolation tests never call _run_bash, + # so a class placeholder is sufficient. + class _DummyThreadedProcessHandle: + def __init__(self, exec_fn, cancel_fn=None): + pass + + def _load_json_store(path): + if path.exists(): + try: + return json.loads(path.read_text()) + except Exception: + pass + return {} + + def _save_json_store(path, data): + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(data, indent=2)) + + def _file_mtime_key(host_path): + try: + st = Path(host_path).stat() + return (st.st_mtime, st.st_size) + except OSError: + return None + + sys.modules["tools.environments.base"] = types.SimpleNamespace( + BaseEnvironment=_DummyBaseEnvironment, + _ThreadedProcessHandle=_DummyThreadedProcessHandle, + _load_json_store=_load_json_store, + _save_json_store=_save_json_store, + _file_mtime_key=_file_mtime_key, + ) sys.modules["tools.interrupt"] = types.SimpleNamespace(is_interrupted=lambda: False) sys.modules["tools.credential_files"] = types.SimpleNamespace( get_credential_file_mounts=lambda: [], iter_skills_files=lambda: [], + iter_cache_files=lambda: [], ) from_id_calls: list[str] = [] diff --git a/tests/tools/test_ssh_environment.py b/tests/tools/test_ssh_environment.py index 9f514e9a90..f6ee967170 100644 --- a/tests/tools/test_ssh_environment.py +++ b/tests/tools/test_ssh_environment.py @@ -43,7 +43,7 @@ class TestBuildSSHCommand: lambda *a, **k: MagicMock(stdout=iter([]), stderr=iter([]), stdin=MagicMock())) - monkeypatch.setattr("tools.environments.ssh.time.sleep", lambda _: None) + monkeypatch.setattr("tools.environments.base.time.sleep", lambda _: None) def test_base_flags(self): env = SSHEnvironment(host="h", user="u") diff --git a/tests/tools/test_terminal_none_command_guard.py b/tests/tools/test_terminal_none_command_guard.py new file mode 100644 index 0000000000..05455836d1 --- /dev/null +++ b/tests/tools/test_terminal_none_command_guard.py @@ -0,0 +1,21 @@ +"""Regression tests for invalid/None terminal command handling.""" + +import json + +from tools.terminal_tool import _transform_sudo_command, terminal_tool + + +def test_transform_sudo_command_none_returns_cleanly(): + transformed, sudo_stdin = _transform_sudo_command(None) + + assert transformed is None + assert sudo_stdin is None + + +def test_terminal_tool_none_command_returns_clean_error(): + result = json.loads(terminal_tool(None)) # type: ignore[arg-type] + + assert result["exit_code"] == -1 + assert result["status"] == "error" + assert "expected string" in result["error"].lower() + assert "nonetype" in result["error"].lower() diff --git a/tests/tools/test_terminal_tool.py b/tests/tools/test_terminal_tool.py new file mode 100644 index 0000000000..42ed693a2e --- /dev/null +++ b/tests/tools/test_terminal_tool.py @@ -0,0 +1,90 @@ +"""Regression tests for sudo detection and sudo password handling.""" + +import tools.terminal_tool as terminal_tool + + +def setup_function(): + terminal_tool._cached_sudo_password = "" + + +def teardown_function(): + terminal_tool._cached_sudo_password = "" + + +def test_searching_for_sudo_does_not_trigger_rewrite(monkeypatch): + monkeypatch.delenv("SUDO_PASSWORD", raising=False) + monkeypatch.delenv("HERMES_INTERACTIVE", raising=False) + + command = "rg --line-number --no-heading --with-filename 'sudo' . | head -n 20" + transformed, sudo_stdin = terminal_tool._transform_sudo_command(command) + + assert transformed == command + assert sudo_stdin is None + + +def test_printf_literal_sudo_does_not_trigger_rewrite(monkeypatch): + monkeypatch.delenv("SUDO_PASSWORD", raising=False) + monkeypatch.delenv("HERMES_INTERACTIVE", raising=False) + + command = "printf '%s\\n' sudo" + transformed, sudo_stdin = terminal_tool._transform_sudo_command(command) + + assert transformed == command + assert sudo_stdin is None + + +def test_non_command_argument_named_sudo_does_not_trigger_rewrite(monkeypatch): + monkeypatch.delenv("SUDO_PASSWORD", raising=False) + monkeypatch.delenv("HERMES_INTERACTIVE", raising=False) + + command = "grep -n sudo README.md" + transformed, sudo_stdin = terminal_tool._transform_sudo_command(command) + + assert transformed == command + assert sudo_stdin is None + + +def test_actual_sudo_command_uses_configured_password(monkeypatch): + monkeypatch.setenv("SUDO_PASSWORD", "testpass") + monkeypatch.delenv("HERMES_INTERACTIVE", raising=False) + + transformed, sudo_stdin = terminal_tool._transform_sudo_command("sudo apt install -y ripgrep") + + assert transformed == "sudo -S -p '' apt install -y ripgrep" + assert sudo_stdin == "testpass\n" + + +def test_actual_sudo_after_leading_env_assignment_is_rewritten(monkeypatch): + monkeypatch.setenv("SUDO_PASSWORD", "testpass") + monkeypatch.delenv("HERMES_INTERACTIVE", raising=False) + + transformed, sudo_stdin = terminal_tool._transform_sudo_command("DEBUG=1 sudo whoami") + + assert transformed == "DEBUG=1 sudo -S -p '' whoami" + assert sudo_stdin == "testpass\n" + + +def test_explicit_empty_sudo_password_tries_empty_without_prompt(monkeypatch): + monkeypatch.setenv("SUDO_PASSWORD", "") + monkeypatch.setenv("HERMES_INTERACTIVE", "1") + + def _fail_prompt(*_args, **_kwargs): + raise AssertionError("interactive sudo prompt should not run for explicit empty password") + + monkeypatch.setattr(terminal_tool, "_prompt_for_sudo_password", _fail_prompt) + + transformed, sudo_stdin = terminal_tool._transform_sudo_command("sudo true") + + assert transformed == "sudo -S -p '' true" + assert sudo_stdin == "\n" + + +def test_cached_sudo_password_is_used_when_env_is_unset(monkeypatch): + monkeypatch.delenv("SUDO_PASSWORD", raising=False) + monkeypatch.delenv("HERMES_INTERACTIVE", raising=False) + terminal_tool._cached_sudo_password = "cached-pass" + + transformed, sudo_stdin = terminal_tool._transform_sudo_command("echo ok && sudo whoami") + + assert transformed == "echo ok && sudo -S -p '' whoami" + assert sudo_stdin == "cached-pass\n" diff --git a/tests/tools/test_threaded_process_handle.py b/tests/tools/test_threaded_process_handle.py new file mode 100644 index 0000000000..4e6fbdb0d6 --- /dev/null +++ b/tests/tools/test_threaded_process_handle.py @@ -0,0 +1,144 @@ +"""Tests for _ThreadedProcessHandle — the adapter for SDK backends.""" + +import threading +import time + +from tools.environments.base import _ThreadedProcessHandle + + +class TestBasicExecution: + def test_successful_execution(self): + def exec_fn(): + return ("hello world", 0) + + handle = _ThreadedProcessHandle(exec_fn) + handle.wait(timeout=5) + + assert handle.returncode == 0 + output = handle.stdout.read() + assert "hello world" in output + + def test_nonzero_exit_code(self): + def exec_fn(): + return ("error occurred", 42) + + handle = _ThreadedProcessHandle(exec_fn) + handle.wait(timeout=5) + + assert handle.returncode == 42 + output = handle.stdout.read() + assert "error occurred" in output + + def test_exception_in_exec_fn(self): + def exec_fn(): + raise RuntimeError("boom") + + handle = _ThreadedProcessHandle(exec_fn) + handle.wait(timeout=5) + + assert handle.returncode == 1 + + def test_empty_output(self): + def exec_fn(): + return ("", 0) + + handle = _ThreadedProcessHandle(exec_fn) + handle.wait(timeout=5) + + assert handle.returncode == 0 + output = handle.stdout.read() + assert output == "" + + +class TestPolling: + def test_poll_returns_none_while_running(self): + event = threading.Event() + + def exec_fn(): + event.wait(timeout=5) + return ("done", 0) + + handle = _ThreadedProcessHandle(exec_fn) + assert handle.poll() is None + + event.set() + handle.wait(timeout=5) + assert handle.poll() == 0 + + def test_poll_returns_returncode_when_done(self): + def exec_fn(): + return ("ok", 0) + + handle = _ThreadedProcessHandle(exec_fn) + handle.wait(timeout=5) + assert handle.poll() == 0 + + +class TestCancelFn: + def test_cancel_fn_called_on_kill(self): + called = threading.Event() + + def cancel(): + called.set() + + def exec_fn(): + time.sleep(10) + return ("", 0) + + handle = _ThreadedProcessHandle(exec_fn, cancel_fn=cancel) + handle.kill() + assert called.is_set() + + def test_cancel_fn_none_is_safe(self): + def exec_fn(): + return ("ok", 0) + + handle = _ThreadedProcessHandle(exec_fn, cancel_fn=None) + handle.kill() # should not raise + handle.wait(timeout=5) + assert handle.returncode == 0 + + def test_cancel_fn_exception_swallowed(self): + def cancel(): + raise RuntimeError("cancel failed") + + def exec_fn(): + return ("ok", 0) + + handle = _ThreadedProcessHandle(exec_fn, cancel_fn=cancel) + handle.kill() # should not raise despite cancel raising + handle.wait(timeout=5) + + +class TestStdoutPipe: + def test_stdout_is_readable(self): + def exec_fn(): + return ("line1\nline2\nline3\n", 0) + + handle = _ThreadedProcessHandle(exec_fn) + handle.wait(timeout=5) + + lines = handle.stdout.readlines() + assert len(lines) == 3 + assert lines[0] == "line1\n" + + def test_stdout_iterable(self): + def exec_fn(): + return ("a\nb\nc\n", 0) + + handle = _ThreadedProcessHandle(exec_fn) + handle.wait(timeout=5) + + collected = list(handle.stdout) + assert len(collected) == 3 + + def test_unicode_output(self): + def exec_fn(): + return ("hello 世界 🌍\n", 0) + + handle = _ThreadedProcessHandle(exec_fn) + handle.wait(timeout=5) + + output = handle.stdout.read() + assert "世界" in output + assert "🌍" in output diff --git a/tools/code_execution_tool.py b/tools/code_execution_tool.py index aa4cd0863f..f0d61210ff 100644 --- a/tools/code_execution_tool.py +++ b/tools/code_execution_tool.py @@ -18,7 +18,7 @@ Architecture (two transports): 2. Parent ships both files to the remote environment 3. Script runs inside the terminal backend (Docker/SSH/Modal/Daytona/etc.) 4. Tool calls are written as request files; a polling thread on the parent - reads them via execute_oneshot(), dispatches, and writes response files + reads them via env.execute(), dispatches, and writes response files 5. The script polls for response files and continues In both cases, only the script's stdout is returned to the LLM; intermediate @@ -536,7 +536,7 @@ def _ship_file_to_remote(env, remote_path: str, content: str) -> None: quotes are fine. """ encoded = base64.b64encode(content.encode("utf-8")).decode("ascii") - env.execute_oneshot( + env.execute( f"echo '{encoded}' | base64 -d > {remote_path}", cwd="/", timeout=30, @@ -555,9 +555,9 @@ def _rpc_poll_loop( ): """Poll the remote filesystem for tool call requests and dispatch them. - Runs in a background thread. Uses ``env.execute_oneshot()`` so it can - operate concurrently with the script-execution thread that holds - ``env.execute()`` (important for persistent-shell backends like SSH). + Runs in a background thread. Each ``env.execute()`` spawns an + independent process, so these calls run safely concurrent with the + script-execution thread. """ from model_tools import handle_function_call @@ -566,7 +566,7 @@ def _rpc_poll_loop( while not stop_event.is_set(): try: # List pending request files (skip .tmp partials) - ls_result = env.execute_oneshot( + ls_result = env.execute( f"ls -1 {rpc_dir}/req_* 2>/dev/null || true", cwd="/", timeout=10, @@ -590,7 +590,7 @@ def _rpc_poll_loop( call_start = time.monotonic() # Read request - read_result = env.execute_oneshot( + read_result = env.execute( f"cat {req_file}", cwd="/", timeout=10, @@ -600,7 +600,7 @@ def _rpc_poll_loop( except (json.JSONDecodeError, ValueError): logger.debug("Malformed RPC request in %s", req_file) # Remove bad request to avoid infinite retry - env.execute_oneshot(f"rm -f {req_file}", cwd="/", timeout=5) + env.execute(f"rm -f {req_file}", cwd="/", timeout=5) continue tool_name = request.get("tool", "") @@ -664,7 +664,7 @@ def _rpc_poll_loop( encoded_result = base64.b64encode( tool_result.encode("utf-8") ).decode("ascii") - env.execute_oneshot( + env.execute( f"echo '{encoded_result}' | base64 -d > {res_file}.tmp" f" && mv {res_file}.tmp {res_file}", cwd="/", @@ -672,7 +672,7 @@ def _rpc_poll_loop( ) # Remove the request file - env.execute_oneshot(f"rm -f {req_file}", cwd="/", timeout=5) + env.execute(f"rm -f {req_file}", cwd="/", timeout=5) except Exception as e: if not stop_event.is_set(): @@ -717,7 +717,7 @@ def _execute_remote( try: # Verify Python is available on the remote - py_check = env.execute_oneshot( + py_check = env.execute( "command -v python3 >/dev/null 2>&1 && echo OK", cwd="/", timeout=15, ) @@ -734,7 +734,7 @@ def _execute_remote( }) # Create sandbox directory on remote - env.execute_oneshot( + env.execute( f"mkdir -p {sandbox_dir}/rpc", cwd="/", timeout=10, ) @@ -806,7 +806,7 @@ def _execute_remote( # Clean up remote sandbox dir try: - env.execute_oneshot( + env.execute( f"rm -rf {sandbox_dir}", cwd="/", timeout=15, ) except Exception: diff --git a/tools/cronjob_tools.py b/tools/cronjob_tools.py index 595ad8bc71..ccb8bc6f63 100644 --- a/tools/cronjob_tools.py +++ b/tools/cronjob_tools.py @@ -455,7 +455,7 @@ Important safety rule: cron-run sessions should not recursively schedule more cr }, "deliver": { "type": "string", - "description": "Delivery target: origin, local, telegram, discord, slack, whatsapp, signal, matrix, mattermost, homeassistant, dingtalk, feishu, wecom, email, sms, or platform:chat_id or platform:chat_id:thread_id for Telegram topics. Examples: 'origin', 'local', 'telegram', 'telegram:-1001234567890:17585', 'discord:#engineering'" + "description": "Delivery target: origin, local, telegram, discord, slack, whatsapp, signal, matrix, mattermost, homeassistant, dingtalk, feishu, wecom, email, sms, bluebubbles, or platform:chat_id or platform:chat_id:thread_id for Telegram topics. Examples: 'origin', 'local', 'telegram', 'telegram:-1001234567890:17585', 'discord:#engineering'" }, "skills": { "type": "array", diff --git a/tools/environments/base.py b/tools/environments/base.py index 21b698ec0c..31ce0e17de 100644 --- a/tools/environments/base.py +++ b/tools/environments/base.py @@ -1,11 +1,27 @@ -"""Base class for all Hermes execution environment backends.""" +"""Base class for all Hermes execution environment backends. -from abc import ABC, abstractmethod +Unified spawn-per-call model: every command spawns a fresh ``bash -c`` process. +A session snapshot (env vars, functions, aliases) is captured once at init and +re-sourced before each command. CWD persists via in-band stdout markers (remote) +or a temp file (local). +""" + +import json +import logging import os +import shlex import subprocess +import threading +import time +import uuid +from abc import ABC, abstractmethod from pathlib import Path +from typing import IO, Callable, Protocol from hermes_constants import get_hermes_home +from tools.interrupt import is_interrupted + +logger = logging.getLogger(__name__) def get_sandbox_dir() -> Path: @@ -23,30 +39,501 @@ def get_sandbox_dir() -> Path: return p -class BaseEnvironment(ABC): - """Common interface for all Hermes execution backends. +# --------------------------------------------------------------------------- +# Shared constants and utilities +# --------------------------------------------------------------------------- - Subclasses implement execute() and cleanup(). Shared helpers eliminate - duplicated subprocess boilerplate across backends. +_SYNC_INTERVAL_SECONDS = 5.0 + + +def _pipe_stdin(proc: subprocess.Popen, data: str) -> None: + """Write *data* to proc.stdin on a daemon thread to avoid pipe-buffer deadlocks.""" + + def _write(): + try: + proc.stdin.write(data) + proc.stdin.close() + except (BrokenPipeError, OSError): + pass + + threading.Thread(target=_write, daemon=True).start() + + +def _popen_bash( + cmd: list[str], stdin_data: str | None = None, **kwargs +) -> subprocess.Popen: + """Spawn a subprocess with standard stdout/stderr/stdin setup. + + If *stdin_data* is provided, writes it asynchronously via :func:`_pipe_stdin`. + Backends with special Popen needs (e.g. local's ``preexec_fn``) can bypass + this and call :func:`_pipe_stdin` directly. """ + proc = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + stdin=subprocess.PIPE if stdin_data is not None else subprocess.DEVNULL, + text=True, + **kwargs, + ) + if stdin_data is not None: + _pipe_stdin(proc, stdin_data) + return proc + + +def _load_json_store(path: Path) -> dict: + """Load a JSON file as a dict, returning ``{}`` on any error.""" + if path.exists(): + try: + return json.loads(path.read_text()) + except Exception: + pass + return {} + + +def _save_json_store(path: Path, data: dict) -> None: + """Write *data* as pretty-printed JSON to *path*.""" + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(data, indent=2)) + + +def _file_mtime_key(host_path: str) -> tuple[float, int] | None: + """Return ``(mtime, size)`` for cache comparison, or ``None`` if unreadable.""" + try: + st = Path(host_path).stat() + return (st.st_mtime, st.st_size) + except OSError: + return None + + +# --------------------------------------------------------------------------- +# ProcessHandle protocol +# --------------------------------------------------------------------------- + + +class ProcessHandle(Protocol): + """Duck type that every backend's _run_bash() must return. + + subprocess.Popen satisfies this natively. SDK backends (Modal, Daytona) + return _ThreadedProcessHandle which adapts their blocking calls. + """ + + def poll(self) -> int | None: ... + def kill(self) -> None: ... + def wait(self, timeout: float | None = None) -> int: ... + + @property + def stdout(self) -> IO[str] | None: ... + + @property + def returncode(self) -> int | None: ... + + +class _ThreadedProcessHandle: + """Adapter for SDK backends (Modal, Daytona) that have no real subprocess. + + Wraps a blocking ``exec_fn() -> (output_str, exit_code)`` in a background + thread and exposes a ProcessHandle-compatible interface. An optional + ``cancel_fn`` is invoked on ``kill()`` for backend-specific cancellation + (e.g. Modal sandbox.terminate, Daytona sandbox.stop). + """ + + def __init__( + self, + exec_fn: Callable[[], tuple[str, int]], + cancel_fn: Callable[[], None] | None = None, + ): + self._cancel_fn = cancel_fn + self._done = threading.Event() + self._returncode: int | None = None + self._error: Exception | None = None + + # Pipe for stdout — drain thread in _wait_for_process reads the read end. + read_fd, write_fd = os.pipe() + self._stdout = os.fdopen(read_fd, "r", encoding="utf-8", errors="replace") + self._write_fd = write_fd + + def _worker(): + try: + output, exit_code = exec_fn() + self._returncode = exit_code + # Write output into the pipe so drain thread picks it up. + try: + os.write(self._write_fd, output.encode("utf-8", errors="replace")) + except OSError: + pass + except Exception as exc: + self._error = exc + self._returncode = 1 + finally: + try: + os.close(self._write_fd) + except OSError: + pass + self._done.set() + + t = threading.Thread(target=_worker, daemon=True) + t.start() + + @property + def stdout(self): + return self._stdout + + @property + def returncode(self) -> int | None: + return self._returncode + + def poll(self) -> int | None: + return self._returncode if self._done.is_set() else None + + def kill(self): + if self._cancel_fn: + try: + self._cancel_fn() + except Exception: + pass + + def wait(self, timeout: float | None = None) -> int: + self._done.wait(timeout=timeout) + return self._returncode + + +# --------------------------------------------------------------------------- +# CWD marker for remote backends +# --------------------------------------------------------------------------- + + +def _cwd_marker(session_id: str) -> str: + return f"__HERMES_CWD_{session_id}__" + + +# --------------------------------------------------------------------------- +# BaseEnvironment +# --------------------------------------------------------------------------- + + +class BaseEnvironment(ABC): + """Common interface and unified execution flow for all Hermes backends. + + Subclasses implement ``_run_bash()`` and ``cleanup()``. The base class + provides ``execute()`` with session snapshot sourcing, CWD tracking, + interrupt handling, and timeout enforcement. + """ + + # Subclasses that embed stdin as a heredoc (Modal, Daytona) set this. + _stdin_mode: str = "pipe" # "pipe" or "heredoc" + + # Snapshot creation timeout (override for slow cold-starts). + _snapshot_timeout: int = 30 def __init__(self, cwd: str, timeout: int, env: dict = None): self.cwd = cwd self.timeout = timeout self.env = env or {} - @abstractmethod - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - """Execute a command, return {"output": str, "returncode": int}.""" - ... + self._session_id = uuid.uuid4().hex[:12] + self._snapshot_path = f"/tmp/hermes-snap-{self._session_id}.sh" + self._cwd_file = f"/tmp/hermes-cwd-{self._session_id}.txt" + self._cwd_marker = _cwd_marker(self._session_id) + self._snapshot_ready = False + self._last_sync_time: float | None = ( + None # set to 0 by backends that need file sync + ) + + # ------------------------------------------------------------------ + # Abstract methods + # ------------------------------------------------------------------ + + def _run_bash( + self, + cmd_string: str, + *, + login: bool = False, + timeout: int = 120, + stdin_data: str | None = None, + ) -> ProcessHandle: + """Spawn a bash process to run *cmd_string*. + + Returns a ProcessHandle (subprocess.Popen or _ThreadedProcessHandle). + Must be overridden by every backend. + """ + raise NotImplementedError(f"{type(self).__name__} must implement _run_bash()") @abstractmethod def cleanup(self): """Release backend resources (container, instance, connection).""" ... + # ------------------------------------------------------------------ + # Session snapshot (init_session) + # ------------------------------------------------------------------ + + def init_session(self): + """Capture login shell environment into a snapshot file. + + Called once after backend construction. On success, sets + ``_snapshot_ready = True`` so subsequent commands source the snapshot + instead of running with ``bash -l``. + """ + # Full capture: env vars, functions (filtered), aliases, shell options. + bootstrap = ( + f"export -p > {self._snapshot_path}\n" + f"declare -f | grep -vE '^_[^_]' >> {self._snapshot_path}\n" + f"alias -p >> {self._snapshot_path}\n" + f"echo 'shopt -s expand_aliases' >> {self._snapshot_path}\n" + f"echo 'set +e' >> {self._snapshot_path}\n" + f"echo 'set +u' >> {self._snapshot_path}\n" + f"pwd -P > {self._cwd_file} 2>/dev/null || true\n" + f"printf '\\n{self._cwd_marker}%s{self._cwd_marker}\\n' \"$(pwd -P)\"\n" + ) + try: + proc = self._run_bash(bootstrap, login=True, timeout=self._snapshot_timeout) + result = self._wait_for_process(proc, timeout=self._snapshot_timeout) + self._snapshot_ready = True + self._update_cwd(result) + logger.info( + "Session snapshot created (session=%s, cwd=%s)", + self._session_id, + self.cwd, + ) + except Exception as exc: + logger.warning( + "init_session failed (session=%s): %s — " + "falling back to bash -l per command", + self._session_id, + exc, + ) + self._snapshot_ready = False + + # ------------------------------------------------------------------ + # Command wrapping + # ------------------------------------------------------------------ + + def _wrap_command(self, command: str, cwd: str) -> str: + """Build the full bash script that sources snapshot, cd's, runs command, + re-dumps env vars, and emits CWD markers.""" + escaped = command.replace("'", "'\\''") + + parts = [] + + # Source snapshot (env vars from previous commands) + if self._snapshot_ready: + parts.append(f"source {self._snapshot_path} 2>/dev/null || true") + + # cd to working directory — let bash expand ~ natively + quoted_cwd = ( + shlex.quote(cwd) if cwd != "~" and not cwd.startswith("~/") else cwd + ) + parts.append(f"cd {quoted_cwd} || exit 126") + + # Run the actual command + parts.append(f"eval '{escaped}'") + parts.append("__hermes_ec=$?") + + # Re-dump env vars to snapshot (last-writer-wins for concurrent calls) + if self._snapshot_ready: + parts.append(f"export -p > {self._snapshot_path} 2>/dev/null || true") + + # Write CWD to file (local reads this) and stdout marker (remote parses this) + parts.append(f"pwd -P > {self._cwd_file} 2>/dev/null || true") + # Use a distinct line for the marker. The leading \n ensures + # the marker starts on its own line even if the command doesn't + # end with a newline (e.g. printf 'exact'). We'll strip this + # injected newline in _extract_cwd_from_output. + parts.append( + f"printf '\\n{self._cwd_marker}%s{self._cwd_marker}\\n' \"$(pwd -P)\"" + ) + parts.append("exit $__hermes_ec") + + return "\n".join(parts) + + # ------------------------------------------------------------------ + # Stdin heredoc embedding (for SDK backends) + # ------------------------------------------------------------------ + + @staticmethod + def _embed_stdin_heredoc(command: str, stdin_data: str) -> str: + """Append stdin_data as a shell heredoc to the command string.""" + delimiter = f"HERMES_STDIN_{uuid.uuid4().hex[:12]}" + return f"{command} << '{delimiter}'\n{stdin_data}\n{delimiter}" + + # ------------------------------------------------------------------ + # Process lifecycle + # ------------------------------------------------------------------ + + def _wait_for_process(self, proc: ProcessHandle, timeout: int = 120) -> dict: + """Poll-based wait with interrupt checking and stdout draining. + + Shared across all backends — not overridden. + """ + output_chunks: list[str] = [] + + def _drain(): + try: + for line in proc.stdout: + output_chunks.append(line) + except UnicodeDecodeError: + output_chunks.clear() + output_chunks.append( + "[binary output detected — raw bytes not displayable]" + ) + except (ValueError, OSError): + pass + + drain_thread = threading.Thread(target=_drain, daemon=True) + drain_thread.start() + deadline = time.monotonic() + timeout + + while proc.poll() is None: + if is_interrupted(): + self._kill_process(proc) + drain_thread.join(timeout=2) + return { + "output": "".join(output_chunks) + "\n[Command interrupted]", + "returncode": 130, + } + if time.monotonic() > deadline: + self._kill_process(proc) + drain_thread.join(timeout=2) + partial = "".join(output_chunks) + timeout_msg = f"\n[Command timed out after {timeout}s]" + return { + "output": partial + timeout_msg + if partial + else timeout_msg.lstrip(), + "returncode": 124, + } + time.sleep(0.2) + + drain_thread.join(timeout=5) + + try: + proc.stdout.close() + except Exception: + pass + + return {"output": "".join(output_chunks), "returncode": proc.returncode} + + def _kill_process(self, proc: ProcessHandle): + """Terminate a process. Subclasses may override for process-group kill.""" + try: + proc.kill() + except (ProcessLookupError, PermissionError, OSError): + pass + + # ------------------------------------------------------------------ + # CWD extraction + # ------------------------------------------------------------------ + + def _update_cwd(self, result: dict): + """Extract CWD from command output. Override for local file-based read.""" + self._extract_cwd_from_output(result) + + def _extract_cwd_from_output(self, result: dict): + """Parse the __HERMES_CWD_{session}__ marker from stdout output. + + Updates self.cwd and strips the marker from result["output"]. + Used by remote backends (Docker, SSH, Modal, Daytona, Singularity). + """ + output = result.get("output", "") + marker = self._cwd_marker + last = output.rfind(marker) + if last == -1: + return + + # Find the opening marker before this closing one + search_start = max(0, last - 4096) # CWD path won't be >4KB + first = output.rfind(marker, search_start, last) + if first == -1 or first == last: + return + + cwd_path = output[first + len(marker) : last].strip() + if cwd_path: + self.cwd = cwd_path + + # Strip the marker line AND the \n we injected before it. + # The wrapper emits: printf '\n__MARKER__%s__MARKER__\n' + # So the output looks like: \n__MARKER__path__MARKER__\n + # We want to remove everything from the injected \n onwards. + line_start = output.rfind("\n", 0, first) + if line_start == -1: + line_start = first + line_end = output.find("\n", last + len(marker)) + line_end = line_end + 1 if line_end != -1 else len(output) + + result["output"] = output[:line_start] + output[line_end:] + + # ------------------------------------------------------------------ + # Hooks + # ------------------------------------------------------------------ + + def _before_execute(self): + """Rate-limited file sync before each command. + + Backends that need pre-command sync set ``self._last_sync_time = 0`` + in ``__init__`` and override :meth:`_sync_files`. Backends needing + extra pre-exec logic (e.g. Daytona sandbox restart check) override + this method and call ``super()._before_execute()``. + """ + if self._last_sync_time is not None: + now = time.monotonic() + if now - self._last_sync_time >= _SYNC_INTERVAL_SECONDS: + self._sync_files() + self._last_sync_time = now + + def _sync_files(self): + """Push files to remote environment. Called rate-limited by _before_execute.""" + pass + + # ------------------------------------------------------------------ + # Unified execute() + # ------------------------------------------------------------------ + + def execute( + self, + command: str, + cwd: str = "", + *, + timeout: int | None = None, + stdin_data: str | None = None, + ) -> dict: + """Execute a command, return {"output": str, "returncode": int}.""" + self._before_execute() + + exec_command, sudo_stdin = self._prepare_command(command) + effective_timeout = timeout or self.timeout + effective_cwd = cwd or self.cwd + + # Merge sudo stdin with caller stdin + if sudo_stdin is not None and stdin_data is not None: + effective_stdin = sudo_stdin + stdin_data + elif sudo_stdin is not None: + effective_stdin = sudo_stdin + else: + effective_stdin = stdin_data + + # Embed stdin as heredoc for backends that need it + if effective_stdin and self._stdin_mode == "heredoc": + exec_command = self._embed_stdin_heredoc(exec_command, effective_stdin) + effective_stdin = None + + wrapped = self._wrap_command(exec_command, effective_cwd) + + # Use login shell if snapshot failed (so user's profile still loads) + login = not self._snapshot_ready + + proc = self._run_bash( + wrapped, login=login, timeout=effective_timeout, stdin_data=effective_stdin + ) + result = self._wait_for_process(proc, timeout=effective_timeout) + self._update_cwd(result) + + return result + + # ------------------------------------------------------------------ + # Shared helpers + # ------------------------------------------------------------------ + def stop(self): """Alias for cleanup (compat with older callers).""" self.cleanup() @@ -57,53 +544,12 @@ class BaseEnvironment(ABC): except Exception: pass - # ------------------------------------------------------------------ - # Shared helpers (eliminate duplication across backends) - # ------------------------------------------------------------------ - def _prepare_command(self, command: str) -> tuple[str, str | None]: - """Transform sudo commands if SUDO_PASSWORD is available. - - Returns: - (transformed_command, sudo_stdin) — see _transform_sudo_command - for the full contract. Callers that drive a subprocess directly - should prepend sudo_stdin (when not None) to any stdin_data they - pass to Popen. Callers that embed stdin via heredoc (modal, - daytona) handle sudo_stdin in their own execute() method. - """ + """Transform sudo commands if SUDO_PASSWORD is available.""" from tools.terminal_tool import _transform_sudo_command + return _transform_sudo_command(command) - def _build_run_kwargs(self, timeout: int | None, - stdin_data: str | None = None) -> dict: - """Build common subprocess.run kwargs for non-interactive execution.""" - kw = { - "text": True, - "timeout": timeout or self.timeout, - "encoding": "utf-8", - "errors": "replace", - "stdout": subprocess.PIPE, - "stderr": subprocess.STDOUT, - } - if stdin_data is not None: - kw["input"] = stdin_data - else: - kw["stdin"] = subprocess.DEVNULL - return kw - - def execute_oneshot(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - """Execute a command bypassing any persistent shell. - - Safe for concurrent use alongside a long-running execute() call. - Backends that maintain a persistent shell (SSH, Local) override this - to route through their oneshot path, avoiding the shell lock. - Non-persistent backends delegate to execute(). - """ - return self.execute(command, cwd=cwd, timeout=timeout, - stdin_data=stdin_data) - def _timeout_result(self, timeout: int | None) -> dict: """Standard return dict when a command times out.""" return { diff --git a/tools/environments/daytona.py b/tools/environments/daytona.py index e52459d8b5..60958fd353 100644 --- a/tools/environments/daytona.py +++ b/tools/environments/daytona.py @@ -6,17 +6,18 @@ and resumed on next creation, preserving the filesystem across sessions. """ import logging -import time import math import shlex import threading -import uuid import warnings from pathlib import Path from typing import Dict, Optional -from tools.environments.base import BaseEnvironment -from tools.interrupt import is_interrupted +from tools.environments.base import ( + BaseEnvironment, + _ThreadedProcessHandle, + _file_mtime_key, +) logger = logging.getLogger(__name__) @@ -24,22 +25,25 @@ logger = logging.getLogger(__name__) class DaytonaEnvironment(BaseEnvironment): """Daytona cloud sandbox execution backend. - Uses stopped/started sandbox lifecycle for filesystem persistence - instead of snapshots, making it faster and stateless on the host. + Spawn-per-call via _ThreadedProcessHandle wrapping blocking SDK calls. + cancel_fn wired to sandbox.stop() for interrupt support. + Shell timeout wrapper preserved (SDK timeout unreliable). """ + _stdin_mode = "heredoc" + def __init__( self, image: str, cwd: str = "/home/daytona", timeout: int = 60, cpu: int = 1, - memory: int = 5120, # MB (hermes convention) - disk: int = 10240, # MB (Daytona platform max is 10GB) + memory: int = 5120, + disk: int = 10240, persistent_filesystem: bool = True, task_id: str = "default", ): - self._requested_cwd = cwd + requested_cwd = cwd super().__init__(cwd=cwd, timeout=timeout) from daytona import ( @@ -53,16 +57,18 @@ class DaytonaEnvironment(BaseEnvironment): self._persistent = persistent_filesystem self._task_id = task_id self._SandboxState = SandboxState + self._DaytonaError = DaytonaError self._daytona = Daytona() self._sandbox = None self._lock = threading.Lock() + self._last_sync_time: float = 0 memory_gib = max(1, math.ceil(memory / 1024)) disk_gib = max(1, math.ceil(disk / 1024)) if disk_gib > 10: warnings.warn( f"Daytona: requested disk ({disk_gib}GB) exceeds platform limit (10GB). " - f"Capping to 10GB. Set container_disk: 10240 in config to silence this.", + f"Capping to 10GB.", stacklevel=2, ) disk_gib = 10 @@ -71,9 +77,7 @@ class DaytonaEnvironment(BaseEnvironment): labels = {"hermes_task_id": task_id} sandbox_name = f"hermes-{task_id}" - # Try to resume an existing sandbox for this task if self._persistent: - # 1. Try name-based lookup (new path) try: self._sandbox = self._daytona.get(sandbox_name) self._sandbox.start() @@ -86,7 +90,6 @@ class DaytonaEnvironment(BaseEnvironment): task_id, e) self._sandbox = None - # 2. Legacy fallback: find sandbox created before the naming migration if self._sandbox is None: try: page = self._daytona.list(labels=labels, page=1, limit=1) @@ -100,7 +103,6 @@ class DaytonaEnvironment(BaseEnvironment): task_id, e) self._sandbox = None - # Create a fresh sandbox if we don't have one if self._sandbox is None: self._sandbox = self._daytona.create( CreateSandboxFromImageParams( @@ -114,32 +116,25 @@ class DaytonaEnvironment(BaseEnvironment): logger.info("Daytona: created sandbox %s for task %s", self._sandbox.id, task_id) - # Detect remote home dir first so mounts go to the right place. + # Detect remote home dir self._remote_home = "/root" try: home = self._sandbox.process.exec("echo $HOME").result.strip() if home: self._remote_home = home - if self._requested_cwd in ("~", "/home/daytona"): + if requested_cwd in ("~", "/home/daytona"): self.cwd = home except Exception: pass logger.info("Daytona: resolved home to %s, cwd to %s", self._remote_home, self.cwd) - # Track synced files to avoid redundant uploads. - # Key: remote_path, Value: (mtime, size) self._synced_files: Dict[str, tuple] = {} - - # Upload credential files and skills directory into the sandbox. - self._sync_skills_and_credentials() + self._sync_files() + self.init_session() def _upload_if_changed(self, host_path: str, remote_path: str) -> bool: - """Upload a file if its mtime/size changed since last sync.""" - hp = Path(host_path) - try: - stat = hp.stat() - file_key = (stat.st_mtime, stat.st_size) - except OSError: + file_key = _file_mtime_key(host_path) + if file_key is None: return False if self._synced_files.get(remote_path) == file_key: return False @@ -153,20 +148,15 @@ class DaytonaEnvironment(BaseEnvironment): logger.debug("Daytona: upload failed %s: %s", host_path, e) return False - def _sync_skills_and_credentials(self) -> None: - """Upload changed credential files and skill files into the sandbox.""" + def _sync_files(self) -> None: container_base = f"{self._remote_home}/.hermes" try: from tools.credential_files import get_credential_file_mounts, iter_skills_files - for mount_entry in get_credential_file_mounts(): remote_path = mount_entry["container_path"].replace("/root/.hermes", container_base, 1) - if self._upload_if_changed(mount_entry["host_path"], remote_path): - logger.debug("Daytona: synced credential %s", remote_path) - + self._upload_if_changed(mount_entry["host_path"], remote_path) for entry in iter_skills_files(container_base=container_base): - if self._upload_if_changed(entry["host_path"], entry["container_path"]): - logger.debug("Daytona: synced skill %s", entry["container_path"]) + self._upload_if_changed(entry["host_path"], entry["container_path"]) except Exception as e: logger.debug("Daytona: could not sync skills/credentials: %s", e) @@ -177,111 +167,36 @@ class DaytonaEnvironment(BaseEnvironment): self._sandbox.start() logger.info("Daytona: restarted sandbox %s", self._sandbox.id) - def _exec_in_thread(self, exec_command: str, cwd: Optional[str], timeout: int) -> dict: - """Run exec in a background thread with interrupt polling. - - The Daytona SDK's exec(timeout=...) parameter is unreliable (the - server-side timeout is not enforced and the SDK has no client-side - fallback), so we wrap the command with the shell ``timeout`` utility - which reliably kills the process and returns exit code 124. - """ - # Wrap with shell `timeout` to enforce the deadline reliably. - # Add a small buffer so the shell timeout fires before any SDK-level - # timeout would, giving us a clean exit code 124. - timed_command = f"timeout {timeout} sh -c {shlex.quote(exec_command)}" - - result_holder: dict = {"value": None, "error": None} - - def _run(): - try: - response = self._sandbox.process.exec( - timed_command, cwd=cwd, - ) - result_holder["value"] = { - "output": response.result or "", - "returncode": response.exit_code, - } - except Exception as e: - result_holder["error"] = e - - t = threading.Thread(target=_run, daemon=True) - t.start() - # Wait for timeout + generous buffer for network/SDK overhead - deadline = time.monotonic() + timeout + 10 - while t.is_alive(): - t.join(timeout=0.2) - if is_interrupted(): - with self._lock: - try: - self._sandbox.stop() - except Exception: - pass - return { - "output": "[Command interrupted - Daytona sandbox stopped]", - "returncode": 130, - } - if time.monotonic() > deadline: - # Shell timeout didn't fire and SDK is hung — force stop - with self._lock: - try: - self._sandbox.stop() - except Exception: - pass - return self._timeout_result(timeout) - - if result_holder["error"]: - return {"error": result_holder["error"]} - return result_holder["value"] - - def execute(self, command: str, cwd: str = "", *, - timeout: Optional[int] = None, - stdin_data: Optional[str] = None) -> dict: + def _before_execute(self): + """Ensure sandbox is ready, then rate-limited file sync via base class.""" with self._lock: self._ensure_sandbox_ready() - # Incremental sync before each command so mid-session credential - # refreshes and skill updates are picked up. - self._sync_skills_and_credentials() + super()._before_execute() - if stdin_data is not None: - marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}" - while marker in stdin_data: - marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}" - command = f"{command} << '{marker}'\n{stdin_data}\n{marker}" + def _run_bash(self, cmd_string: str, *, login: bool = False, + timeout: int = 120, + stdin_data: str | None = None): + """Return a _ThreadedProcessHandle wrapping a blocking Daytona SDK call.""" + sandbox = self._sandbox + lock = self._lock - exec_command, sudo_stdin = self._prepare_command(command) + def cancel(): + with lock: + try: + sandbox.stop() + except Exception: + pass - # Daytona sandboxes execute commands via the Daytona SDK and cannot - # pipe subprocess stdin directly the way a local Popen can. When a - # sudo password is present, use a shell-level pipe from printf so that - # the password feeds sudo -S without appearing as an echo argument - # embedded in the shell string. The password is still visible in the - # remote sandbox's command line, but it is not exposed on the user's - # local machine — which is the primary threat being mitigated. - if sudo_stdin is not None: - import shlex - exec_command = ( - f"printf '%s\\n' {shlex.quote(sudo_stdin.rstrip())} | {exec_command}" - ) - effective_cwd = cwd or self.cwd or None - effective_timeout = timeout or self.timeout + if login: + shell_cmd = f"bash -l -c {shlex.quote(cmd_string)}" + else: + shell_cmd = f"bash -c {shlex.quote(cmd_string)}" - result = self._exec_in_thread(exec_command, effective_cwd, effective_timeout) + def exec_fn() -> tuple[str, int]: + response = sandbox.process.exec(shell_cmd, timeout=timeout) + return (response.result or "", response.exit_code) - if "error" in result: - from daytona import DaytonaError - err = result["error"] - if isinstance(err, DaytonaError): - with self._lock: - try: - self._ensure_sandbox_ready() - except Exception: - return {"output": f"Daytona execution error: {err}", "returncode": 1} - result = self._exec_in_thread(exec_command, effective_cwd, effective_timeout) - if "error" not in result: - return result - return {"output": f"Daytona execution error: {err}", "returncode": 1} - - return result + return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel) def cleanup(self): with self._lock: diff --git a/tools/environments/docker.py b/tools/environments/docker.py index b97040d4e0..59a2377961 100644 --- a/tools/environments/docker.py +++ b/tools/environments/docker.py @@ -8,18 +8,14 @@ persistence via bind mounts. import logging import os import re -import shlex import shutil import subprocess import sys -import threading -import time import uuid from typing import Optional -from tools.environments.base import BaseEnvironment +from tools.environments.base import BaseEnvironment, _popen_bash from tools.environments.local import _HERMES_PROVIDER_ENV_BLOCKLIST -from tools.interrupt import is_interrupted logger = logging.getLogger(__name__) @@ -431,6 +427,69 @@ class DockerEnvironment(BaseEnvironment): self._container_id = result.stdout.strip() logger.info(f"Started container {container_name} ({self._container_id[:12]})") + # Build the init-time env forwarding args (used only by init_session + # to inject host env vars into the snapshot; subsequent commands get + # them from the snapshot file). + self._init_env_args = self._build_init_env_args() + + # Initialize session snapshot inside the container + self.init_session() + + def _build_init_env_args(self) -> list[str]: + """Build -e KEY=VALUE args for injecting host env vars into init_session. + + These are used once during init_session() so that export -p captures + them into the snapshot. Subsequent execute() calls don't need -e flags. + """ + exec_env: dict[str, str] = dict(self._env) + + explicit_forward_keys = set(self._forward_env) + passthrough_keys: set[str] = set() + try: + from tools.env_passthrough import get_all_passthrough + passthrough_keys = set(get_all_passthrough()) + except Exception: + pass + # Explicit docker_forward_env entries are an intentional opt-in and must + # win over the generic Hermes secret blocklist. Only implicit passthrough + # keys are filtered. + forward_keys = explicit_forward_keys | (passthrough_keys - _HERMES_PROVIDER_ENV_BLOCKLIST) + hermes_env = _load_hermes_env_vars() if forward_keys else {} + for key in sorted(forward_keys): + value = os.getenv(key) + if value is None: + value = hermes_env.get(key) + if value is not None: + exec_env[key] = value + + args = [] + for key in sorted(exec_env): + args.extend(["-e", f"{key}={exec_env[key]}"]) + return args + + def _run_bash(self, cmd_string: str, *, login: bool = False, + timeout: int = 120, + stdin_data: str | None = None) -> subprocess.Popen: + """Spawn a bash process inside the Docker container.""" + assert self._container_id, "Container not started" + cmd = [self._docker_exe, "exec"] + if stdin_data is not None: + cmd.append("-i") + + # Only inject -e env args during init_session (login=True). + # Subsequent commands get env vars from the snapshot. + if login: + cmd.extend(self._init_env_args) + + cmd.extend([self._container_id]) + + if login: + cmd.extend(["bash", "-l", "-c", cmd_string]) + else: + cmd.extend(["bash", "-c", cmd_string]) + + return _popen_bash(cmd, stdin_data) + @staticmethod def _storage_opt_supported() -> bool: """Check if Docker's storage driver supports --storage-opt size=. @@ -471,112 +530,6 @@ class DockerEnvironment(BaseEnvironment): logger.debug("Docker --storage-opt support: %s", _storage_opt_ok) return _storage_opt_ok - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - exec_command, sudo_stdin = self._prepare_command(command) - work_dir = cwd or self.cwd - effective_timeout = timeout or self.timeout - - # Merge sudo password (if any) with caller-supplied stdin_data. - if sudo_stdin is not None and stdin_data is not None: - effective_stdin = sudo_stdin + stdin_data - elif sudo_stdin is not None: - effective_stdin = sudo_stdin - else: - effective_stdin = stdin_data - - # docker exec -w doesn't expand ~, so prepend a cd into the command. - # Keep ~ unquoted (for shell expansion) and quote only the subpath. - if work_dir == "~": - exec_command = f"cd ~ && {exec_command}" - work_dir = "/" - elif work_dir.startswith("~/"): - exec_command = f"cd ~/{shlex.quote(work_dir[2:])} && {exec_command}" - work_dir = "/" - - assert self._container_id, "Container not started" - cmd = [self._docker_exe, "exec"] - if effective_stdin is not None: - cmd.append("-i") - cmd.extend(["-w", work_dir]) - # Build the per-exec environment: start with explicit docker_env values - # (static config), then overlay docker_forward_env / skill env_passthrough - # (dynamic from host process). Forward values take precedence. - exec_env: dict[str, str] = dict(self._env) - - explicit_forward_keys = set(self._forward_env) - passthrough_keys: set[str] = set() - try: - from tools.env_passthrough import get_all_passthrough - passthrough_keys = set(get_all_passthrough()) - except Exception: - pass - # Explicit docker_forward_env entries are an intentional opt-in and must - # win over the generic Hermes secret blocklist. Only implicit passthrough - # keys are filtered. - forward_keys = explicit_forward_keys | (passthrough_keys - _HERMES_PROVIDER_ENV_BLOCKLIST) - hermes_env = _load_hermes_env_vars() if forward_keys else {} - for key in sorted(forward_keys): - value = os.getenv(key) - if value is None: - value = hermes_env.get(key) - if value is not None: - exec_env[key] = value - - for key in sorted(exec_env): - cmd.extend(["-e", f"{key}={exec_env[key]}"]) - cmd.extend([self._container_id, "bash", "-lc", exec_command]) - - try: - _output_chunks = [] - proc = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL, - text=True, - ) - if effective_stdin: - try: - proc.stdin.write(effective_stdin) - proc.stdin.close() - except Exception: - pass - - def _drain(): - try: - for line in proc.stdout: - _output_chunks.append(line) - except Exception: - pass - - reader = threading.Thread(target=_drain, daemon=True) - reader.start() - deadline = time.monotonic() + effective_timeout - - while proc.poll() is None: - if is_interrupted(): - proc.terminate() - try: - proc.wait(timeout=1) - except subprocess.TimeoutExpired: - proc.kill() - reader.join(timeout=2) - return { - "output": "".join(_output_chunks) + "\n[Command interrupted]", - "returncode": 130, - } - if time.monotonic() > deadline: - proc.kill() - reader.join(timeout=2) - return self._timeout_result(effective_timeout) - time.sleep(0.2) - - reader.join(timeout=5) - return {"output": "".join(_output_chunks), "returncode": proc.returncode} - except Exception as e: - return {"output": f"Docker execution error: {e}", "returncode": 1} - def cleanup(self): """Stop and remove the container. Bind-mount dirs persist if persistent=True.""" if self._container_id: diff --git a/tools/environments/local.py b/tools/environments/local.py index 27282b6ef6..d3bb344829 100644 --- a/tools/environments/local.py +++ b/tools/environments/local.py @@ -1,42 +1,22 @@ -"""Local execution environment with interrupt support and non-blocking I/O.""" +"""Local execution environment — spawn-per-call with session snapshot.""" -import glob import os import platform import shutil import signal import subprocess -import threading -import time + +from tools.environments.base import BaseEnvironment, _pipe_stdin _IS_WINDOWS = platform.system() == "Windows" -from tools.environments.base import BaseEnvironment -from tools.environments.persistent_shell import PersistentShellMixin -from tools.interrupt import is_interrupted - -# Unique marker to isolate real command output from shell init/exit noise. -# printf (no trailing newline) keeps the boundaries clean for splitting. -_OUTPUT_FENCE = "__HERMES_FENCE_a9f7b3__" # Hermes-internal env vars that should NOT leak into terminal subprocesses. -# These are loaded from ~/.hermes/.env for Hermes' own LLM/provider calls -# but can break external CLIs (e.g. codex) that also honor them. -# See: https://github.com/NousResearch/hermes-agent/issues/1002 -# -# Built dynamically from the provider registry so new providers are -# automatically covered without manual blocklist maintenance. _HERMES_PROVIDER_ENV_FORCE_PREFIX = "_HERMES_FORCE_" def _build_provider_env_blocklist() -> frozenset: - """Derive the blocklist from provider, tool, and gateway config. - - Automatically picks up api_key_env_vars and base_url_env_var from - every registered provider, plus tool/messaging env vars from the - optional config registry, so new Hermes-managed secrets are blocked - in subprocesses without having to maintain multiple static lists. - """ + """Derive the blocklist from provider, tool, and gateway config.""" blocked: set[str] = set() try: @@ -59,33 +39,30 @@ def _build_provider_env_blocklist() -> frozenset: except ImportError: pass - # Vars not covered above but still Hermes-internal / conflict-prone. blocked.update({ "OPENAI_BASE_URL", "OPENAI_API_KEY", - "OPENAI_API_BASE", # legacy alias + "OPENAI_API_BASE", "OPENAI_ORG_ID", "OPENAI_ORGANIZATION", "OPENROUTER_API_KEY", "ANTHROPIC_BASE_URL", - "ANTHROPIC_TOKEN", # OAuth token (not in registry as env var) + "ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN", "LLM_MODEL", - # Expanded isolation for other major providers (Issue #1002) - "GOOGLE_API_KEY", # Gemini / Google AI Studio - "DEEPSEEK_API_KEY", # DeepSeek - "MISTRAL_API_KEY", # Mistral AI - "GROQ_API_KEY", # Groq - "TOGETHER_API_KEY", # Together AI - "PERPLEXITY_API_KEY", # Perplexity - "COHERE_API_KEY", # Cohere - "FIREWORKS_API_KEY", # Fireworks AI - "XAI_API_KEY", # xAI (Grok) - "HELICONE_API_KEY", # LLM Observability proxy + "GOOGLE_API_KEY", + "DEEPSEEK_API_KEY", + "MISTRAL_API_KEY", + "GROQ_API_KEY", + "TOGETHER_API_KEY", + "PERPLEXITY_API_KEY", + "COHERE_API_KEY", + "FIREWORKS_API_KEY", + "XAI_API_KEY", + "HELICONE_API_KEY", "PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "FIRECRAWL_API_URL", - # Gateway/runtime config not represented in OPTIONAL_ENV_VARS. "TELEGRAM_HOME_CHANNEL", "TELEGRAM_HOME_CHANNEL_NAME", "DISCORD_HOME_CHANNEL", @@ -115,12 +92,10 @@ def _build_provider_env_blocklist() -> frozenset: "EMAIL_HOME_ADDRESS", "EMAIL_HOME_ADDRESS_NAME", "GATEWAY_ALLOWED_USERS", - # Skills Hub / GitHub app auth paths and aliases. "GH_TOKEN", "GITHUB_APP_ID", "GITHUB_APP_PRIVATE_KEY_PATH", "GITHUB_APP_INSTALLATION_ID", - # Remote sandbox backend credentials. "MODAL_TOKEN_ID", "MODAL_TOKEN_SECRET", "DAYTONA_API_KEY", @@ -132,13 +107,7 @@ _HERMES_PROVIDER_ENV_BLOCKLIST = _build_provider_env_blocklist() def _sanitize_subprocess_env(base_env: dict | None, extra_env: dict | None = None) -> dict: - """Filter Hermes-managed secrets from a subprocess environment. - - `_HERMES_FORCE_` entries in ``extra_env`` opt a blocked variable back in - intentionally for callers that truly need it. Vars registered via - :mod:`tools.env_passthrough` (skill-declared or user-configured) also - bypass the blocklist. - """ + """Filter Hermes-managed secrets from a subprocess environment.""" try: from tools.env_passthrough import is_env_passthrough as _is_passthrough except Exception: @@ -163,33 +132,24 @@ def _sanitize_subprocess_env(base_env: dict | None, extra_env: dict | None = Non def _find_bash() -> str: - """Find bash for command execution. - - The fence wrapper uses bash syntax (semicolons, $?, printf), so we - must use bash — not the user's $SHELL which could be fish/zsh/etc. - On Windows: uses Git Bash (bundled with Git for Windows). - """ + """Find bash for command execution.""" if not _IS_WINDOWS: return ( shutil.which("bash") or ("/usr/bin/bash" if os.path.isfile("/usr/bin/bash") else None) or ("/bin/bash" if os.path.isfile("/bin/bash") else None) - or os.environ.get("SHELL") # last resort: whatever they have + or os.environ.get("SHELL") or "/bin/sh" ) - # Windows: look for Git Bash (installed with Git for Windows). - # Allow override via env var (same pattern as Claude Code). custom = os.environ.get("HERMES_GIT_BASH_PATH") if custom and os.path.isfile(custom): return custom - # shutil.which finds bash.exe if Git\bin is on PATH found = shutil.which("bash") if found: return found - # Check common Git for Windows install locations for candidate in ( os.path.join(os.environ.get("ProgramFiles", r"C:\Program Files"), "Git", "bin", "bash.exe"), os.path.join(os.environ.get("ProgramFiles(x86)", r"C:\Program Files (x86)"), "Git", "bin", "bash.exe"), @@ -209,60 +169,7 @@ def _find_bash() -> str: _find_shell = _find_bash -# Noise lines emitted by interactive shells when stdin is not a terminal. -# Used as a fallback when output fence markers are missing. -_SHELL_NOISE_SUBSTRINGS = ( - # bash - "bash: cannot set terminal process group", - "bash: no job control in this shell", - "no job control in this shell", - "cannot set terminal process group", - "tcsetattr: Inappropriate ioctl for device", - # zsh / oh-my-zsh / macOS terminal session - "Restored session:", - "Saving session...", - "Last login:", - "command not found:", - "Oh My Zsh", - "compinit:", -) - - -def _clean_shell_noise(output: str) -> str: - """Strip shell startup/exit warnings that leak when using -i without a TTY. - - Removes lines matching known noise patterns from both the beginning - and end of the output. Lines in the middle are left untouched. - """ - - def _is_noise(line: str) -> bool: - return any(noise in line for noise in _SHELL_NOISE_SUBSTRINGS) - - lines = output.split("\n") - - # Strip leading noise - while lines and _is_noise(lines[0]): - lines.pop(0) - - # Strip trailing noise (walk backwards, skip empty lines from split) - end = len(lines) - 1 - while end >= 0 and (not lines[end] or _is_noise(lines[end])): - end -= 1 - - if end < 0: - return "" - - cleaned = lines[: end + 1] - result = "\n".join(cleaned) - - # Preserve trailing newline if original had one - if output.endswith("\n") and result and not result.endswith("\n"): - result += "\n" - return result - - -# Standard PATH entries for environments with minimal PATH (e.g. systemd services). -# Includes macOS Homebrew paths (/opt/homebrew/* for Apple Silicon). +# Standard PATH entries for environments with minimal PATH. _SANE_PATH = ( "/opt/homebrew/bin:/opt/homebrew/sbin:" "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" @@ -290,197 +197,76 @@ def _make_run_env(env: dict) -> dict: return run_env -def _extract_fenced_output(raw: str) -> str: - """Extract real command output from between fence markers. - - The execute() method wraps each command with printf(FENCE) markers. - This function finds the first and last fence and returns only the - content between them, which is the actual command output free of - any shell init/exit noise. - - Falls back to pattern-based _clean_shell_noise if fences are missing. - """ - first = raw.find(_OUTPUT_FENCE) - if first == -1: - return _clean_shell_noise(raw) - - start = first + len(_OUTPUT_FENCE) - last = raw.rfind(_OUTPUT_FENCE) - - if last <= first: - # Only start fence found (e.g. user command called `exit`) - return _clean_shell_noise(raw[start:]) - - return raw[start:last] - - -class LocalEnvironment(PersistentShellMixin, BaseEnvironment): +class LocalEnvironment(BaseEnvironment): """Run commands directly on the host machine. - Features: - - Popen + polling for interrupt support (user can cancel mid-command) - - Background stdout drain thread to prevent pipe buffer deadlocks - - stdin_data support for piping content (bypasses ARG_MAX limits) - - sudo -S transform via SUDO_PASSWORD env var - - Uses interactive login shell so full user env is available - - Optional persistent shell mode (cwd/env vars survive across calls) + Spawn-per-call: every execute() spawns a fresh bash process. + Session snapshot preserves env vars across calls. + CWD persists via file-based read after each command. """ - def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None, - persistent: bool = False): + def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None): super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env) - self.persistent = persistent - if self.persistent: - self._init_persistent_shell() + self.init_session() - @property - def _temp_prefix(self) -> str: - return f"/tmp/hermes-local-{self._session_id}" - - def _spawn_shell_process(self) -> subprocess.Popen: - user_shell = _find_bash() - run_env = _make_run_env(self.env) - return subprocess.Popen( - [user_shell, "-l"], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.DEVNULL, - text=True, - env=run_env, - preexec_fn=None if _IS_WINDOWS else os.setsid, - ) - - def _read_temp_files(self, *paths: str) -> list[str]: - results = [] - for path in paths: - if os.path.exists(path): - with open(path) as f: - results.append(f.read()) - else: - results.append("") - return results - - def _kill_shell_children(self): - if self._shell_pid is None: - return - try: - subprocess.run( - ["pkill", "-P", str(self._shell_pid)], - capture_output=True, timeout=5, - ) - except (subprocess.TimeoutExpired, FileNotFoundError): - pass - - def _cleanup_temp_files(self): - for f in glob.glob(f"{self._temp_prefix}-*"): - if os.path.exists(f): - os.remove(f) - - def _execute_oneshot(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - work_dir = cwd or self.cwd or os.getcwd() - effective_timeout = timeout or self.timeout - exec_command, sudo_stdin = self._prepare_command(command) - - if sudo_stdin is not None and stdin_data is not None: - effective_stdin = sudo_stdin + stdin_data - elif sudo_stdin is not None: - effective_stdin = sudo_stdin - else: - effective_stdin = stdin_data - - user_shell = _find_bash() - # Newline-separated wrapper (not `cmd; __hermes_rc=...` on one line). - # A trailing `; __hermes_rc` glued to `< subprocess.Popen: + bash = _find_bash() + args = [bash, "-l", "-c", cmd_string] if login else [bash, "-c", cmd_string] run_env = _make_run_env(self.env) proc = subprocess.Popen( - [user_shell, "-lic", fenced_cmd], + args, text=True, - cwd=work_dir, env=run_env, encoding="utf-8", errors="replace", stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL, + stdin=subprocess.PIPE if stdin_data is not None else subprocess.DEVNULL, preexec_fn=None if _IS_WINDOWS else os.setsid, ) - if effective_stdin is not None: - def _write_stdin(): + if stdin_data is not None: + _pipe_stdin(proc, stdin_data) + + return proc + + def _kill_process(self, proc): + """Kill the entire process group (all children).""" + try: + if _IS_WINDOWS: + proc.terminate() + else: + pgid = os.getpgid(proc.pid) + os.killpg(pgid, signal.SIGTERM) try: - proc.stdin.write(effective_stdin) - proc.stdin.close() - except (BrokenPipeError, OSError): - pass - threading.Thread(target=_write_stdin, daemon=True).start() - - _output_chunks: list[str] = [] - - def _drain_stdout(): + proc.wait(timeout=1.0) + except subprocess.TimeoutExpired: + os.killpg(pgid, signal.SIGKILL) + except (ProcessLookupError, PermissionError): try: - for line in proc.stdout: - _output_chunks.append(line) - except ValueError: + proc.kill() + except Exception: pass - finally: - try: - proc.stdout.close() - except Exception: - pass - reader = threading.Thread(target=_drain_stdout, daemon=True) - reader.start() - deadline = time.monotonic() + effective_timeout + def _update_cwd(self, result: dict): + """Read CWD from temp file (local-only, no round-trip needed).""" + try: + cwd_path = open(self._cwd_file).read().strip() + if cwd_path: + self.cwd = cwd_path + except (OSError, FileNotFoundError): + pass - while proc.poll() is None: - if is_interrupted(): - try: - if _IS_WINDOWS: - proc.terminate() - else: - pgid = os.getpgid(proc.pid) - os.killpg(pgid, signal.SIGTERM) - try: - proc.wait(timeout=1.0) - except subprocess.TimeoutExpired: - os.killpg(pgid, signal.SIGKILL) - except (ProcessLookupError, PermissionError): - proc.kill() - reader.join(timeout=2) - return { - "output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]", - "returncode": 130, - } - if time.monotonic() > deadline: - try: - if _IS_WINDOWS: - proc.terminate() - else: - os.killpg(os.getpgid(proc.pid), signal.SIGTERM) - except (ProcessLookupError, PermissionError): - proc.kill() - reader.join(timeout=2) - partial = "".join(_output_chunks) - timeout_msg = f"\n[Command timed out after {effective_timeout}s]" - return { - "output": partial + timeout_msg if partial else timeout_msg.lstrip(), - "returncode": 124, - } - time.sleep(0.2) + # Still strip the marker from output so it's not visible + self._extract_cwd_from_output(result) - reader.join(timeout=5) - output = _extract_fenced_output("".join(_output_chunks)) - return {"output": output, "returncode": proc.returncode} + def cleanup(self): + """Clean up temp files.""" + for f in (self._snapshot_path, self._cwd_file): + try: + os.unlink(f) + except OSError: + pass diff --git a/tools/environments/managed_modal.py b/tools/environments/managed_modal.py index a8197bccf2..52b00f19a3 100644 --- a/tools/environments/managed_modal.py +++ b/tools/environments/managed_modal.py @@ -10,7 +10,7 @@ import uuid from dataclasses import dataclass from typing import Any, Dict, Optional -from tools.environments.modal_common import ( +from tools.environments.modal_utils import ( BaseModalExecutionEnvironment, ModalExecStart, PreparedModalExec, diff --git a/tools/environments/modal.py b/tools/environments/modal.py index 7916a2c449..1cb8e47969 100644 --- a/tools/environments/modal.py +++ b/tools/environments/modal.py @@ -5,19 +5,19 @@ wrapper, while preserving Hermes' persistent snapshot behavior across sessions. """ import asyncio -import json import logging import shlex import threading -from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, Optional from hermes_constants import get_hermes_home -from tools.environments.modal_common import ( - BaseModalExecutionEnvironment, - ModalExecStart, - PreparedModalExec, +from tools.environments.base import ( + BaseEnvironment, + _ThreadedProcessHandle, + _file_mtime_key, + _load_json_store, + _save_json_store, ) logger = logging.getLogger(__name__) @@ -26,20 +26,12 @@ _SNAPSHOT_STORE = get_hermes_home() / "modal_snapshots.json" _DIRECT_SNAPSHOT_NAMESPACE = "direct" -def _load_snapshots() -> Dict[str, str]: - """Load snapshot ID mapping from disk.""" - if _SNAPSHOT_STORE.exists(): - try: - return json.loads(_SNAPSHOT_STORE.read_text()) - except Exception: - pass - return {} +def _load_snapshots() -> dict: + return _load_json_store(_SNAPSHOT_STORE) -def _save_snapshots(data: Dict[str, str]) -> None: - """Persist snapshot ID mapping to disk.""" - _SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True) - _SNAPSHOT_STORE.write_text(json.dumps(data, indent=2)) +def _save_snapshots(data: dict) -> None: + _save_json_store(_SNAPSHOT_STORE, data) def _direct_snapshot_key(task_id: str) -> str: @@ -47,23 +39,18 @@ def _direct_snapshot_key(task_id: str) -> str: def _get_snapshot_restore_candidate(task_id: str) -> tuple[str | None, bool]: - """Return a snapshot id and whether it came from the legacy key format.""" snapshots = _load_snapshots() - namespaced_key = _direct_snapshot_key(task_id) snapshot_id = snapshots.get(namespaced_key) if isinstance(snapshot_id, str) and snapshot_id: return snapshot_id, False - legacy_snapshot_id = snapshots.get(task_id) if isinstance(legacy_snapshot_id, str) and legacy_snapshot_id: return legacy_snapshot_id, True - return None, False def _store_direct_snapshot(task_id: str, snapshot_id: str) -> None: - """Persist the direct Modal snapshot id under the direct namespace.""" snapshots = _load_snapshots() snapshots[_direct_snapshot_key(task_id)] = snapshot_id snapshots.pop(task_id, None) @@ -71,10 +58,8 @@ def _store_direct_snapshot(task_id: str, snapshot_id: str) -> None: def _delete_direct_snapshot(task_id: str, snapshot_id: str | None = None) -> None: - """Remove direct Modal snapshot entries for a task, including legacy keys.""" snapshots = _load_snapshots() updated = False - for key in (_direct_snapshot_key(task_id), task_id): value = snapshots.get(key) if value is None: @@ -82,13 +67,15 @@ def _delete_direct_snapshot(task_id: str, snapshot_id: str | None = None) -> Non if snapshot_id is None or value == snapshot_id: snapshots.pop(key, None) updated = True - if updated: _save_snapshots(snapshots) def _resolve_modal_image(image_spec: Any) -> Any: - """Convert registry references or snapshot ids into Modal image objects.""" + """Convert registry references or snapshot ids into Modal image objects. + + Includes add_python support for ubuntu/debian images (absorbed from PR 4511). + """ import modal as _modal if not isinstance(image_spec, str): @@ -97,12 +84,22 @@ def _resolve_modal_image(image_spec: Any) -> Any: if image_spec.startswith("im-"): return _modal.Image.from_id(image_spec) + # PR 4511: add python to ubuntu/debian images that don't have it + lower = image_spec.lower() + add_python = any(base in lower for base in ("ubuntu", "debian")) + + setup_commands = [ + "RUN rm -rf /usr/local/lib/python*/site-packages/pip* 2>/dev/null; " + "python -m ensurepip --upgrade --default-pip 2>/dev/null || true", + ] + if add_python: + setup_commands.insert(0, + "RUN apt-get update -qq && apt-get install -y -qq python3 python3-venv > /dev/null 2>&1 || true" + ) + return _modal.Image.from_registry( image_spec, - setup_dockerfile_commands=[ - "RUN rm -rf /usr/local/lib/python*/site-packages/pip* 2>/dev/null; " - "python -m ensurepip --upgrade --default-pip 2>/dev/null || true", - ], + setup_dockerfile_commands=setup_commands, ) @@ -138,19 +135,15 @@ class _AsyncWorker: self._thread.join(timeout=10) -@dataclass -class _DirectModalExecHandle: - thread: threading.Thread - result_holder: Dict[str, Any] +class ModalEnvironment(BaseEnvironment): + """Modal cloud execution via native Modal sandboxes. - -class ModalEnvironment(BaseModalExecutionEnvironment): - """Modal cloud execution via native Modal sandboxes.""" + Spawn-per-call via _ThreadedProcessHandle wrapping async SDK calls. + cancel_fn wired to sandbox.terminate for interrupt support. + """ _stdin_mode = "heredoc" - _poll_interval_seconds = 0.2 - _interrupt_output = "[Command interrupted - Modal sandbox terminated]" - _unexpected_error_prefix = "Modal execution error" + _snapshot_timeout = 60 # Modal cold starts can be slow def __init__( self, @@ -170,6 +163,7 @@ class ModalEnvironment(BaseModalExecutionEnvironment): self._app = None self._worker = _AsyncWorker() self._synced_files: Dict[str, tuple] = {} + self._last_sync_time: float = 0 sandbox_kwargs = dict(modal_sandbox_kwargs or {}) @@ -199,27 +193,13 @@ class ModalEnvironment(BaseModalExecutionEnvironment): remote_path=mount_entry["container_path"], ) ) - logger.info( - "Modal: mounting credential %s -> %s", - mount_entry["host_path"], - mount_entry["container_path"], - ) - - # Mount individual skill files (symlinks filtered out). - skills_files = iter_skills_files() - for entry in skills_files: + for entry in iter_skills_files(): cred_mounts.append( _modal.Mount.from_local_file( entry["host_path"], remote_path=entry["container_path"], ) ) - if skills_files: - logger.info("Modal: mounting %d skill files", len(skills_files)) - - # Mount host-side cache files (documents, images, audio, - # screenshots). New files arriving mid-session are picked up - # by _sync_files() before each command execution. cache_files = iter_cache_files() for entry in cache_files: cred_mounts.append( @@ -228,8 +208,6 @@ class ModalEnvironment(BaseModalExecutionEnvironment): remote_path=entry["container_path"], ) ) - if cache_files: - logger.info("Modal: mounting %d cache files", len(cache_files)) except Exception as e: logger.debug("Modal: could not load credential file mounts: %s", e) @@ -243,8 +221,7 @@ class ModalEnvironment(BaseModalExecutionEnvironment): existing_mounts.extend(cred_mounts) create_kwargs["mounts"] = existing_mounts sandbox = await _modal.Sandbox.create.aio( - "sleep", - "infinity", + "sleep", "infinity", image=image_spec, app=app, timeout=int(create_kwargs.pop("timeout", 3600)), @@ -255,57 +232,41 @@ class ModalEnvironment(BaseModalExecutionEnvironment): try: target_image_spec = restored_snapshot_id or image try: - # _resolve_modal_image keeps the Modal bootstrap fix together: - # it applies setup_dockerfile_commands with ensurepip before - # Modal builds registry images, while snapshot ids restore via - # modal.Image.from_id() without rebuilding. effective_image = _resolve_modal_image(target_image_spec) self._app, self._sandbox = self._worker.run_coroutine( - _create_sandbox(effective_image), - timeout=300, + _create_sandbox(effective_image), timeout=300, ) except Exception as exc: if not restored_snapshot_id: raise - logger.warning( "Modal: failed to restore snapshot %s, retrying with base image: %s", - restored_snapshot_id[:20], - exc, + restored_snapshot_id[:20], exc, ) _delete_direct_snapshot(self._task_id, restored_snapshot_id) base_image = _resolve_modal_image(image) self._app, self._sandbox = self._worker.run_coroutine( - _create_sandbox(base_image), - timeout=300, + _create_sandbox(base_image), timeout=300, ) else: if restored_snapshot_id and restored_from_legacy_key: _store_direct_snapshot(self._task_id, restored_snapshot_id) - logger.info( - "Modal: migrated legacy snapshot entry for task %s", - self._task_id, - ) except Exception: self._worker.stop() raise logger.info("Modal: sandbox created (task=%s)", self._task_id) + self.init_session() def _push_file_to_sandbox(self, host_path: str, container_path: str) -> bool: - """Push a single file into the sandbox if changed. Returns True if synced.""" - hp = Path(host_path) - try: - stat = hp.stat() - file_key = (stat.st_mtime, stat.st_size) - except OSError: + """Push a single file into the sandbox if changed.""" + file_key = _file_mtime_key(host_path) + if file_key is None: return False - if self._synced_files.get(container_path) == file_key: return False - try: - content = hp.read_bytes() + content = Path(host_path).read_bytes() except Exception: return False @@ -326,85 +287,55 @@ class ModalEnvironment(BaseModalExecutionEnvironment): return True def _sync_files(self) -> None: - """Push credential, skill, and cache files into the running sandbox. - - Runs before each command. Uses mtime+size caching so only changed - files are pushed (~13μs overhead in the no-op case). Cache files - are especially important here — new uploads/screenshots may appear - mid-session after sandbox creation. - """ + """Push credential, skill, and cache files into the running sandbox.""" try: from tools.credential_files import ( get_credential_file_mounts, iter_skills_files, iter_cache_files, ) - for entry in get_credential_file_mounts(): - if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]): - logger.debug("Modal: synced credential %s", entry["container_path"]) - + self._push_file_to_sandbox(entry["host_path"], entry["container_path"]) for entry in iter_skills_files(): - if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]): - logger.debug("Modal: synced skill file %s", entry["container_path"]) - + self._push_file_to_sandbox(entry["host_path"], entry["container_path"]) for entry in iter_cache_files(): - if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]): - logger.debug("Modal: synced cache file %s", entry["container_path"]) + self._push_file_to_sandbox(entry["host_path"], entry["container_path"]) except Exception as e: logger.debug("Modal: file sync failed: %s", e) - def _before_execute(self) -> None: - self._sync_files() + def _run_bash(self, cmd_string: str, *, login: bool = False, + timeout: int = 120, + stdin_data: str | None = None): + """Return a _ThreadedProcessHandle wrapping an async Modal sandbox exec.""" + sandbox = self._sandbox + worker = self._worker - def _start_modal_exec(self, prepared: PreparedModalExec) -> ModalExecStart: - full_command = f"cd {shlex.quote(prepared.cwd)} && {prepared.command}" - result_holder = {"value": None, "error": None} + def cancel(): + worker.run_coroutine(sandbox.terminate.aio(), timeout=15) - def _run(): - try: - async def _do_execute(): - process = await self._sandbox.exec.aio( - "bash", - "-c", - full_command, - timeout=prepared.timeout, - ) - stdout = await process.stdout.read.aio() - stderr = await process.stderr.read.aio() - exit_code = await process.wait.aio() - if isinstance(stdout, bytes): - stdout = stdout.decode("utf-8", errors="replace") - if isinstance(stderr, bytes): - stderr = stderr.decode("utf-8", errors="replace") - output = stdout - if stderr: - output = f"{stdout}\n{stderr}" if stdout else stderr - return self._result(output, exit_code) + def exec_fn() -> tuple[str, int]: + async def _do(): + args = ["bash"] + if login: + args.extend(["-l", "-c", cmd_string]) + else: + args.extend(["-c", cmd_string]) + process = await sandbox.exec.aio(*args, timeout=timeout) + stdout = await process.stdout.read.aio() + stderr = await process.stderr.read.aio() + exit_code = await process.wait.aio() + if isinstance(stdout, bytes): + stdout = stdout.decode("utf-8", errors="replace") + if isinstance(stderr, bytes): + stderr = stderr.decode("utf-8", errors="replace") + output = stdout + if stderr: + output = f"{stdout}\n{stderr}" if stdout else stderr + return output, exit_code - result_holder["value"] = self._worker.run_coroutine( - _do_execute(), - timeout=prepared.timeout + 30, - ) - except Exception as e: - result_holder["error"] = e + return worker.run_coroutine(_do(), timeout=timeout + 30) - t = threading.Thread(target=_run, daemon=True) - t.start() - return ModalExecStart(handle=_DirectModalExecHandle(thread=t, result_holder=result_holder)) - - def _poll_modal_exec(self, handle: _DirectModalExecHandle) -> dict | None: - if handle.thread.is_alive(): - return None - if handle.result_holder["error"]: - return self._error_result(f"Modal execution error: {handle.result_holder['error']}") - return handle.result_holder["value"] - - def _cancel_modal_exec(self, handle: _DirectModalExecHandle) -> None: - self._worker.run_coroutine( - self._sandbox.terminate.aio(), - timeout=15, - ) + return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel) def cleanup(self): """Snapshot the filesystem (if persistent) then stop the sandbox.""" @@ -426,17 +357,13 @@ class ModalEnvironment(BaseModalExecutionEnvironment): _store_direct_snapshot(self._task_id, snapshot_id) logger.info( "Modal: saved filesystem snapshot %s for task %s", - snapshot_id[:20], - self._task_id, + snapshot_id[:20], self._task_id, ) except Exception as e: logger.warning("Modal: filesystem snapshot failed: %s", e) try: - self._worker.run_coroutine( - self._sandbox.terminate.aio(), - timeout=15, - ) + self._worker.run_coroutine(self._sandbox.terminate.aio(), timeout=15) except Exception: pass finally: diff --git a/tools/environments/modal_common.py b/tools/environments/modal_utils.py similarity index 91% rename from tools/environments/modal_common.py rename to tools/environments/modal_utils.py index 0affd02095..0db8194719 100644 --- a/tools/environments/modal_common.py +++ b/tools/environments/modal_utils.py @@ -56,7 +56,15 @@ def wrap_modal_sudo_pipe(command: str, sudo_stdin: str) -> str: class BaseModalExecutionEnvironment(BaseEnvironment): - """Common execute() flow for direct and managed Modal transports.""" + """Execution flow for the *managed* Modal transport (gateway-owned sandbox). + + This deliberately overrides :meth:`BaseEnvironment.execute` because the + tool-gateway handles command preparation, CWD tracking, and env-snapshot + management on the server side. The base class's ``_wrap_command`` / + ``_wait_for_process`` / snapshot machinery does not apply here — the + gateway owns that responsibility. See ``ManagedModalEnvironment`` for the + concrete subclass. + """ _stdin_mode = "payload" _poll_interval_seconds = 0.25 @@ -124,7 +132,7 @@ class BaseModalExecutionEnvironment(BaseEnvironment): def _before_execute(self) -> None: """Hook for backends that need pre-exec sync or validation.""" - return None + pass def _prepare_modal_exec( self, diff --git a/tools/environments/persistent_shell.py b/tools/environments/persistent_shell.py deleted file mode 100644 index c4344ff5a1..0000000000 --- a/tools/environments/persistent_shell.py +++ /dev/null @@ -1,290 +0,0 @@ -"""Persistent shell mixin: file-based IPC protocol for long-lived bash shells.""" - -import logging -import shlex -import subprocess -import threading -import time -import uuid -from abc import abstractmethod - -from tools.interrupt import is_interrupted - -logger = logging.getLogger(__name__) - - -class PersistentShellMixin: - """Mixin that adds persistent shell capability to any BaseEnvironment. - - Subclasses must implement ``_spawn_shell_process()``, ``_read_temp_files()``, - ``_kill_shell_children()``, ``_execute_oneshot()``, and ``_cleanup_temp_files()``. - """ - - persistent: bool - - @abstractmethod - def _spawn_shell_process(self) -> subprocess.Popen: ... - - @abstractmethod - def _read_temp_files(self, *paths: str) -> list[str]: ... - - @abstractmethod - def _kill_shell_children(self): ... - - @abstractmethod - def _execute_oneshot(self, command: str, cwd: str, *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: ... - - @abstractmethod - def _cleanup_temp_files(self): ... - - _session_id: str = "" - _poll_interval_start: float = 0.01 # initial poll interval (10ms) - _poll_interval_max: float = 0.25 # max poll interval (250ms) — reduces I/O for long commands - - @property - def _temp_prefix(self) -> str: - return f"/tmp/hermes-persistent-{self._session_id}" - - # ------------------------------------------------------------------ - # Lifecycle - # ------------------------------------------------------------------ - - def _init_persistent_shell(self): - self._shell_lock = threading.Lock() - self._shell_proc: subprocess.Popen | None = None - self._shell_alive: bool = False - self._shell_pid: int | None = None - - self._session_id = uuid.uuid4().hex[:12] - p = self._temp_prefix - self._pshell_stdout = f"{p}-stdout" - self._pshell_stderr = f"{p}-stderr" - self._pshell_status = f"{p}-status" - self._pshell_cwd = f"{p}-cwd" - self._pshell_pid_file = f"{p}-pid" - - self._shell_proc = self._spawn_shell_process() - self._shell_alive = True - - self._drain_thread = threading.Thread( - target=self._drain_shell_output, daemon=True, - ) - self._drain_thread.start() - - init_script = ( - f"export TERM=${{TERM:-dumb}}\n" - f"touch {self._pshell_stdout} {self._pshell_stderr} " - f"{self._pshell_status} {self._pshell_cwd} {self._pshell_pid_file}\n" - f"echo $$ > {self._pshell_pid_file}\n" - f"pwd > {self._pshell_cwd}\n" - ) - self._send_to_shell(init_script) - - deadline = time.monotonic() + 3.0 - while time.monotonic() < deadline: - pid_str = self._read_temp_files(self._pshell_pid_file)[0].strip() - if pid_str.isdigit(): - self._shell_pid = int(pid_str) - break - time.sleep(0.05) - else: - logger.warning("Could not read persistent shell PID") - self._shell_pid = None - - if self._shell_pid: - logger.info( - "Persistent shell started (session=%s, pid=%d)", - self._session_id, self._shell_pid, - ) - - reported_cwd = self._read_temp_files(self._pshell_cwd)[0].strip() - if reported_cwd: - self.cwd = reported_cwd - - def _cleanup_persistent_shell(self): - if self._shell_proc is None: - return - - if self._session_id: - self._cleanup_temp_files() - - try: - self._shell_proc.stdin.close() - except Exception: - pass - try: - self._shell_proc.terminate() - self._shell_proc.wait(timeout=3) - except subprocess.TimeoutExpired: - self._shell_proc.kill() - - self._shell_alive = False - self._shell_proc = None - - if hasattr(self, "_drain_thread") and self._drain_thread.is_alive(): - self._drain_thread.join(timeout=1.0) - - # ------------------------------------------------------------------ - # execute() / cleanup() — shared dispatcher, subclasses inherit - # ------------------------------------------------------------------ - - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - if self.persistent: - return self._execute_persistent( - command, cwd, timeout=timeout, stdin_data=stdin_data, - ) - return self._execute_oneshot( - command, cwd, timeout=timeout, stdin_data=stdin_data, - ) - - def execute_oneshot(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - """Always use the oneshot (non-persistent) execution path. - - This bypasses _shell_lock so it can run concurrently with a - long-running command in the persistent shell — used by - execute_code's file-based RPC polling thread. - """ - return self._execute_oneshot( - command, cwd, timeout=timeout, stdin_data=stdin_data, - ) - - def cleanup(self): - if self.persistent: - self._cleanup_persistent_shell() - - # ------------------------------------------------------------------ - # Shell I/O - # ------------------------------------------------------------------ - - def _drain_shell_output(self): - try: - for _ in self._shell_proc.stdout: - pass - except Exception: - pass - self._shell_alive = False - - def _send_to_shell(self, text: str): - if not self._shell_alive or self._shell_proc is None: - return - try: - self._shell_proc.stdin.write(text) - self._shell_proc.stdin.flush() - except (BrokenPipeError, OSError): - self._shell_alive = False - - def _read_persistent_output(self) -> tuple[str, int, str]: - stdout, stderr, status_raw, cwd = self._read_temp_files( - self._pshell_stdout, self._pshell_stderr, - self._pshell_status, self._pshell_cwd, - ) - output = self._merge_output(stdout, stderr) - status = status_raw.strip() - if ":" in status: - status = status.split(":", 1)[1] - try: - exit_code = int(status.strip()) - except ValueError: - exit_code = 1 - return output, exit_code, cwd.strip() - - # ------------------------------------------------------------------ - # Execution - # ------------------------------------------------------------------ - - def _execute_persistent(self, command: str, cwd: str, *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - if not self._shell_alive: - logger.info("Persistent shell died, restarting...") - self._init_persistent_shell() - - exec_command, sudo_stdin = self._prepare_command(command) - effective_timeout = timeout or self.timeout - if stdin_data or sudo_stdin: - return self._execute_oneshot( - command, cwd, timeout=timeout, stdin_data=stdin_data, - ) - - with self._shell_lock: - return self._execute_persistent_locked( - exec_command, cwd, effective_timeout, - ) - - def _execute_persistent_locked(self, command: str, cwd: str, - timeout: int) -> dict: - work_dir = cwd or self.cwd - cmd_id = uuid.uuid4().hex[:8] - truncate = ( - f": > {self._pshell_stdout}\n" - f": > {self._pshell_stderr}\n" - f": > {self._pshell_status}\n" - ) - self._send_to_shell(truncate) - escaped = command.replace("'", "'\\''") - - ipc_script = ( - f"cd {shlex.quote(work_dir)}\n" - f"eval '{escaped}' < /dev/null > {self._pshell_stdout} 2> {self._pshell_stderr}\n" - f"__EC=$?\n" - f"pwd > {self._pshell_cwd}\n" - f"echo {cmd_id}:$__EC > {self._pshell_status}\n" - ) - self._send_to_shell(ipc_script) - deadline = time.monotonic() + timeout - poll_interval = self._poll_interval_start # starts at 10ms, backs off to 250ms - - while True: - if is_interrupted(): - self._kill_shell_children() - output, _, _ = self._read_persistent_output() - return { - "output": output + "\n[Command interrupted]", - "returncode": 130, - } - - if time.monotonic() > deadline: - self._kill_shell_children() - output, _, _ = self._read_persistent_output() - if output: - return { - "output": output + f"\n[Command timed out after {timeout}s]", - "returncode": 124, - } - return self._timeout_result(timeout) - - if not self._shell_alive: - return { - "output": "Persistent shell died during execution", - "returncode": 1, - } - - status_content = self._read_temp_files(self._pshell_status)[0].strip() - if status_content.startswith(cmd_id + ":"): - break - - time.sleep(poll_interval) - # Exponential backoff: fast start (10ms) for quick commands, - # ramps up to 250ms for long-running commands — reduces I/O by 10-25x - # on WSL2 where polling keeps the VM hot and memory pressure high. - poll_interval = min(poll_interval * 1.5, self._poll_interval_max) - - output, exit_code, new_cwd = self._read_persistent_output() - if new_cwd: - self.cwd = new_cwd - return {"output": output, "returncode": exit_code} - - @staticmethod - def _merge_output(stdout: str, stderr: str) -> str: - parts = [] - if stdout.strip(): - parts.append(stdout.rstrip("\n")) - if stderr.strip(): - parts.append(stderr.rstrip("\n")) - return "\n".join(parts) diff --git a/tools/environments/singularity.py b/tools/environments/singularity.py index 0ea5037c84..16d1013fed 100644 --- a/tools/environments/singularity.py +++ b/tools/environments/singularity.py @@ -5,20 +5,22 @@ Supports configurable resource limits and optional filesystem persistence via writable overlay directories that survive across sessions. """ -import json import logging import os -import shlex import shutil import subprocess import threading import uuid from pathlib import Path -from typing import Dict, Optional +from typing import Optional from hermes_constants import get_hermes_home -from tools.environments.base import BaseEnvironment -from tools.interrupt import is_interrupted +from tools.environments.base import ( + BaseEnvironment, + _load_json_store, + _popen_bash, + _save_json_store, +) logger = logging.getLogger(__name__) @@ -26,11 +28,7 @@ _SNAPSHOT_STORE = get_hermes_home() / "singularity_snapshots.json" def _find_singularity_executable() -> str: - """Locate the apptainer or singularity CLI binary. - - Returns the executable name (``"apptainer"`` or ``"singularity"``). - Raises ``RuntimeError`` with install instructions if neither is found. - """ + """Locate the apptainer or singularity CLI binary.""" if shutil.which("apptainer"): return "apptainer" if shutil.which("singularity"): @@ -43,66 +41,34 @@ def _find_singularity_executable() -> str: def _ensure_singularity_available() -> str: - """Preflight check: resolve the executable and verify it responds. - - Returns the executable name on success. - Raises ``RuntimeError`` with an actionable message on failure. - """ + """Preflight check: resolve the executable and verify it responds.""" exe = _find_singularity_executable() - try: result = subprocess.run( - [exe, "version"], - capture_output=True, - text=True, - timeout=10, + [exe, "version"], capture_output=True, text=True, timeout=10, ) except FileNotFoundError: raise RuntimeError( - f"Singularity backend selected but the resolved executable '{exe}' " - "could not be executed. Check your installation." + f"Singularity backend selected but '{exe}' could not be executed." ) except subprocess.TimeoutExpired: - raise RuntimeError( - f"'{exe} version' timed out. The runtime may be misconfigured." - ) + raise RuntimeError(f"'{exe} version' timed out.") if result.returncode != 0: stderr = result.stderr.strip()[:200] - raise RuntimeError( - f"'{exe} version' failed (exit code {result.returncode}): {stderr}" - ) - + raise RuntimeError(f"'{exe} version' failed (exit code {result.returncode}): {stderr}") return exe -def _load_snapshots() -> Dict[str, str]: - if _SNAPSHOT_STORE.exists(): - try: - return json.loads(_SNAPSHOT_STORE.read_text()) - except Exception: - pass - return {} +def _load_snapshots() -> dict: + return _load_json_store(_SNAPSHOT_STORE) -def _save_snapshots(data: Dict[str, str]) -> None: - _SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True) - _SNAPSHOT_STORE.write_text(json.dumps(data, indent=2)) +def _save_snapshots(data: dict) -> None: + _save_json_store(_SNAPSHOT_STORE, data) -# ------------------------------------------------------------------------- -# Singularity helpers (scratch dir, SIF cache, SIF building) -# ------------------------------------------------------------------------- - def _get_scratch_dir() -> Path: - """Get the best directory for Singularity sandboxes. - - Resolution order: - 1. TERMINAL_SCRATCH_DIR (explicit override) - 2. TERMINAL_SANDBOX_DIR / singularity (shared sandbox root) - 3. /scratch (common on HPC clusters) - 4. ~/.hermes/sandboxes/singularity (fallback) - """ custom_scratch = os.getenv("TERMINAL_SCRATCH_DIR") if custom_scratch: scratch_path = Path(custom_scratch) @@ -124,7 +90,6 @@ def _get_scratch_dir() -> Path: def _get_apptainer_cache_dir() -> Path: - """Get the Apptainer cache directory for SIF images.""" cache_dir = os.getenv("APPTAINER_CACHEDIR") if cache_dir: cache_path = Path(cache_dir) @@ -140,11 +105,6 @@ _sif_build_lock = threading.Lock() def _get_or_build_sif(image: str, executable: str = "apptainer") -> str: - """Get or build a SIF image from a docker:// URL. - - Returns the path unchanged if it's already a .sif file. - For docker:// URLs, checks the cache and builds if needed. - """ if image.endswith('.sif') and Path(image).exists(): return image if not image.startswith('docker://'): @@ -193,19 +153,12 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str: return image -# ------------------------------------------------------------------------- -# SingularityEnvironment -# ------------------------------------------------------------------------- - class SingularityEnvironment(BaseEnvironment): """Hardened Singularity/Apptainer container with resource limits and persistence. - Security: --containall (isolated PID/IPC/mount namespaces, no host home mount), - --no-home, writable-tmpfs for scratch space. The container cannot see or modify - the host filesystem outside of explicitly bound paths. - - Persistence: when enabled, the writable overlay directory is preserved across - sessions so installed packages and files survive cleanup/restore. + Spawn-per-call: every execute() spawns a fresh ``apptainer exec ... bash -c`` process. + Session snapshot preserves env vars across calls. + CWD persists via in-band stdout markers. """ def __init__( @@ -227,12 +180,9 @@ class SingularityEnvironment(BaseEnvironment): self._persistent = persistent_filesystem self._task_id = task_id self._overlay_dir: Optional[Path] = None - - # Resource limits self._cpu = cpu self._memory = memory - # Persistent overlay directory if self._persistent: overlay_base = _get_scratch_dir() / "hermes-overlays" overlay_base.mkdir(parents=True, exist_ok=True) @@ -240,42 +190,26 @@ class SingularityEnvironment(BaseEnvironment): self._overlay_dir.mkdir(parents=True, exist_ok=True) self._start_instance() + self.init_session() def _start_instance(self): cmd = [self.executable, "instance", "start"] - - # Security: full isolation from host cmd.extend(["--containall", "--no-home"]) - # Writable layer if self._persistent and self._overlay_dir: - # Persistent writable overlay -- survives across restarts cmd.extend(["--overlay", str(self._overlay_dir)]) else: cmd.append("--writable-tmpfs") - # Mount credential files and skills directory (read-only). try: from tools.credential_files import get_credential_file_mounts, get_skills_directory_mount - for mount_entry in get_credential_file_mounts(): cmd.extend(["--bind", f"{mount_entry['host_path']}:{mount_entry['container_path']}:ro"]) - logger.info( - "Singularity: binding credential %s -> %s", - mount_entry["host_path"], - mount_entry["container_path"], - ) for skills_mount in get_skills_directory_mount(): cmd.extend(["--bind", f"{skills_mount['host_path']}:{skills_mount['container_path']}:ro"]) - logger.info( - "Singularity: binding skills dir %s -> %s", - skills_mount["host_path"], - skills_mount["container_path"], - ) except Exception as e: logger.debug("Singularity: could not load credential/skills mounts: %s", e) - # Resource limits (cgroup-based, may require root or appropriate config) if self._memory > 0: cmd.extend(["--memory", f"{self._memory}M"]) if self._cpu > 0: @@ -288,94 +222,29 @@ class SingularityEnvironment(BaseEnvironment): if result.returncode != 0: raise RuntimeError(f"Failed to start instance: {result.stderr}") self._instance_started = True - logger.info("Singularity instance %s started (persistent=%s)", + logger.info("Singularity instance %s started (persistent=%s)", self.instance_id, self._persistent) except subprocess.TimeoutExpired: raise RuntimeError("Instance start timed out") - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: + def _run_bash(self, cmd_string: str, *, login: bool = False, + timeout: int = 120, + stdin_data: str | None = None) -> subprocess.Popen: + """Spawn a bash process inside the Singularity instance.""" if not self._instance_started: - return {"output": "Instance not started", "returncode": -1} + raise RuntimeError("Singularity instance not started") - effective_timeout = timeout or self.timeout - work_dir = cwd or self.cwd - exec_command, sudo_stdin = self._prepare_command(command) - - # Merge sudo password (if any) with caller-supplied stdin_data. - if sudo_stdin is not None and stdin_data is not None: - effective_stdin = sudo_stdin + stdin_data - elif sudo_stdin is not None: - effective_stdin = sudo_stdin + cmd = [self.executable, "exec", + f"instance://{self.instance_id}"] + if login: + cmd.extend(["bash", "-l", "-c", cmd_string]) else: - effective_stdin = stdin_data + cmd.extend(["bash", "-c", cmd_string]) - # apptainer exec --pwd doesn't expand ~, so prepend a cd into the command. - # Keep ~ unquoted (for shell expansion) and quote only the subpath. - if work_dir == "~": - exec_command = f"cd ~ && {exec_command}" - work_dir = "/tmp" - elif work_dir.startswith("~/"): - exec_command = f"cd ~/{shlex.quote(work_dir[2:])} && {exec_command}" - work_dir = "/tmp" - - cmd = [self.executable, "exec", "--pwd", work_dir, - f"instance://{self.instance_id}", - "bash", "-c", exec_command] - - try: - import time as _time - _output_chunks = [] - proc = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL, - text=True, - ) - if effective_stdin: - try: - proc.stdin.write(effective_stdin) - proc.stdin.close() - except Exception: - pass - - def _drain(): - try: - for line in proc.stdout: - _output_chunks.append(line) - except Exception: - pass - - reader = threading.Thread(target=_drain, daemon=True) - reader.start() - deadline = _time.monotonic() + effective_timeout - - while proc.poll() is None: - if is_interrupted(): - proc.terminate() - try: - proc.wait(timeout=1) - except subprocess.TimeoutExpired: - proc.kill() - reader.join(timeout=2) - return { - "output": "".join(_output_chunks) + "\n[Command interrupted]", - "returncode": 130, - } - if _time.monotonic() > deadline: - proc.kill() - reader.join(timeout=2) - return self._timeout_result(effective_timeout) - _time.sleep(0.2) - - reader.join(timeout=5) - return {"output": "".join(_output_chunks), "returncode": proc.returncode} - except Exception as e: - return {"output": f"Singularity execution error: {e}", "returncode": 1} + return _popen_bash(cmd, stdin_data) def cleanup(self): - """Stop the instance. If persistent, the overlay dir survives for next creation.""" + """Stop the instance. If persistent, the overlay dir survives.""" if self._instance_started: try: subprocess.run( @@ -387,7 +256,6 @@ class SingularityEnvironment(BaseEnvironment): logger.warning("Failed to stop Singularity instance %s: %s", self.instance_id, e) self._instance_started = False - # Record overlay path for persistence restoration if self._persistent and self._overlay_dir: snapshots = _load_snapshots() snapshots[self._task_id] = str(self._overlay_dir) diff --git a/tools/environments/ssh.py b/tools/environments/ssh.py index afd28c4aff..a77eb5c9f4 100644 --- a/tools/environments/ssh.py +++ b/tools/environments/ssh.py @@ -5,13 +5,9 @@ import shlex import shutil import subprocess import tempfile -import threading -import time from pathlib import Path -from tools.environments.base import BaseEnvironment -from tools.environments.persistent_shell import PersistentShellMixin -from tools.interrupt import is_interrupted +from tools.environments.base import BaseEnvironment, _popen_bash logger = logging.getLogger(__name__) @@ -24,32 +20,22 @@ def _ensure_ssh_available() -> None: ) -class SSHEnvironment(PersistentShellMixin, BaseEnvironment): +class SSHEnvironment(BaseEnvironment): """Run commands on a remote machine over SSH. - Uses SSH ControlMaster for connection persistence so subsequent - commands are fast. Security benefit: the agent cannot modify its - own code since execution happens on a separate machine. - - Foreground commands are interruptible: the local ssh process is killed - and a remote kill is attempted over the ControlMaster socket. - - When ``persistent=True``, a single long-lived bash shell is kept alive - over SSH and state (cwd, env vars, shell variables) persists across - ``execute()`` calls. Output capture uses file-based IPC on the remote - host (stdout/stderr/exit-code written to temp files, polled via fast - ControlMaster one-shot reads). + Spawn-per-call: every execute() spawns a fresh ``ssh ... bash -c`` process. + Session snapshot preserves env vars across calls. + CWD persists via in-band stdout markers. + Uses SSH ControlMaster for connection reuse. """ def __init__(self, host: str, user: str, cwd: str = "~", - timeout: int = 60, port: int = 22, key_path: str = "", - persistent: bool = False): + timeout: int = 60, port: int = 22, key_path: str = ""): super().__init__(cwd=cwd, timeout=timeout) self.host = host self.user = user self.port = port self.key_path = key_path - self.persistent = persistent self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh" self.control_dir.mkdir(parents=True, exist_ok=True) @@ -57,10 +43,10 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): _ensure_ssh_available() self._establish_connection() self._remote_home = self._detect_remote_home() - self._sync_skills_and_credentials() + self._last_sync_time: float = 0 # guarantees first _before_execute syncs + self._sync_files() - if self.persistent: - self._init_persistent_shell() + self.init_session() def _build_ssh_command(self, extra_args: list | None = None) -> list: cmd = ["ssh"] @@ -102,12 +88,11 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): return home except Exception: pass - # Fallback: guess from username if self.user == "root": return "/root" return f"/home/{self.user}" - def _sync_skills_and_credentials(self) -> None: + def _sync_files(self) -> None: """Rsync skills directory and credential files to the remote host.""" try: container_base = f"{self._remote_home}/.hermes" @@ -122,7 +107,6 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): rsync_base.extend(["-e", ssh_opts]) dest_prefix = f"{self.user}@{self.host}" - # Sync individual credential files (remap /root/.hermes to detected home) for mount_entry in get_credential_file_mounts(): remote_path = mount_entry["container_path"].replace("/root/.hermes", container_base, 1) parent_dir = str(Path(remote_path).parent) @@ -136,7 +120,6 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): else: logger.debug("SSH: rsync credential failed: %s", result.stderr.strip()) - # Sync skill directories (local + external, remap to detected home) for skills_mount in get_skills_directory_mount(container_base=container_base): remote_path = skills_mount["container_path"] mkdir_cmd = self._build_ssh_command() @@ -154,152 +137,19 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment): except Exception as e: logger.debug("SSH: could not sync skills/credentials: %s", e) - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - # Incremental sync before each command so mid-session credential - # refreshes and skill updates are picked up. - self._sync_skills_and_credentials() - return super().execute(command, cwd, timeout=timeout, stdin_data=stdin_data) - - _poll_interval_start: float = 0.15 # SSH: higher initial interval (150ms) for network latency - - @property - def _temp_prefix(self) -> str: - return f"/tmp/hermes-ssh-{self._session_id}" - - def _spawn_shell_process(self) -> subprocess.Popen: + def _run_bash(self, cmd_string: str, *, login: bool = False, + timeout: int = 120, + stdin_data: str | None = None) -> subprocess.Popen: + """Spawn an SSH process that runs bash on the remote host.""" cmd = self._build_ssh_command() - cmd.append("bash -l") - return subprocess.Popen( - cmd, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.DEVNULL, - text=True, - ) - - def _read_temp_files(self, *paths: str) -> list[str]: - if len(paths) == 1: - cmd = self._build_ssh_command() - cmd.append(f"cat {paths[0]} 2>/dev/null") - try: - result = subprocess.run( - cmd, capture_output=True, text=True, timeout=10, - ) - return [result.stdout] - except (subprocess.TimeoutExpired, OSError): - return [""] - - delim = f"__HERMES_SEP_{self._session_id}__" - script = "; ".join( - f"cat {p} 2>/dev/null; echo '{delim}'" for p in paths - ) - cmd = self._build_ssh_command() - cmd.append(script) - try: - result = subprocess.run( - cmd, capture_output=True, text=True, timeout=10, - ) - parts = result.stdout.split(delim + "\n") - return [parts[i] if i < len(parts) else "" for i in range(len(paths))] - except (subprocess.TimeoutExpired, OSError): - return [""] * len(paths) - - def _kill_shell_children(self): - if self._shell_pid is None: - return - cmd = self._build_ssh_command() - cmd.append(f"pkill -P {self._shell_pid} 2>/dev/null; true") - try: - subprocess.run(cmd, capture_output=True, timeout=5) - except (subprocess.TimeoutExpired, OSError): - pass - - def _cleanup_temp_files(self): - cmd = self._build_ssh_command() - cmd.append(f"rm -f {self._temp_prefix}-*") - try: - subprocess.run(cmd, capture_output=True, timeout=5) - except (subprocess.TimeoutExpired, OSError): - pass - - def _execute_oneshot(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: - work_dir = cwd or self.cwd - exec_command, sudo_stdin = self._prepare_command(command) - # Keep ~ unquoted (for shell expansion) and quote only the subpath. - if work_dir == "~": - wrapped = f'cd ~ && {exec_command}' - elif work_dir.startswith("~/"): - wrapped = f'cd ~/{shlex.quote(work_dir[2:])} && {exec_command}' + if login: + cmd.extend(["bash", "-l", "-c", shlex.quote(cmd_string)]) else: - wrapped = f'cd {shlex.quote(work_dir)} && {exec_command}' - effective_timeout = timeout or self.timeout + cmd.extend(["bash", "-c", shlex.quote(cmd_string)]) - if sudo_stdin is not None and stdin_data is not None: - effective_stdin = sudo_stdin + stdin_data - elif sudo_stdin is not None: - effective_stdin = sudo_stdin - else: - effective_stdin = stdin_data - - cmd = self._build_ssh_command() - cmd.append(wrapped) - - kwargs = self._build_run_kwargs(timeout, effective_stdin) - kwargs.pop("timeout", None) - _output_chunks = [] - proc = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL, - text=True, - ) - - if effective_stdin: - try: - proc.stdin.write(effective_stdin) - proc.stdin.close() - except (BrokenPipeError, OSError): - pass - - def _drain(): - try: - for line in proc.stdout: - _output_chunks.append(line) - except Exception: - pass - - reader = threading.Thread(target=_drain, daemon=True) - reader.start() - deadline = time.monotonic() + effective_timeout - - while proc.poll() is None: - if is_interrupted(): - proc.terminate() - try: - proc.wait(timeout=1) - except subprocess.TimeoutExpired: - proc.kill() - reader.join(timeout=2) - return { - "output": "".join(_output_chunks) + "\n[Command interrupted]", - "returncode": 130, - } - if time.monotonic() > deadline: - proc.kill() - reader.join(timeout=2) - return self._timeout_result(effective_timeout) - time.sleep(0.2) - - reader.join(timeout=5) - return {"output": "".join(_output_chunks), "returncode": proc.returncode} + return _popen_bash(cmd, stdin_data) def cleanup(self): - super().cleanup() if self.control_socket.exists(): try: cmd = ["ssh", "-o", f"ControlPath={self.control_socket}", diff --git a/tools/send_message_tool.py b/tools/send_message_tool.py index 164b8a2f47..76b3e15820 100644 --- a/tools/send_message_tool.py +++ b/tools/send_message_tool.py @@ -148,6 +148,7 @@ def _handle_send(args): "slack": Platform.SLACK, "whatsapp": Platform.WHATSAPP, "signal": Platform.SIGNAL, + "bluebubbles": Platform.BLUEBUBBLES, "matrix": Platform.MATRIX, "mattermost": Platform.MATTERMOST, "homeassistant": Platform.HOMEASSISTANT, @@ -396,6 +397,8 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None, result = await _send_feishu(pconfig, chat_id, chunk, thread_id=thread_id) elif platform == Platform.WECOM: result = await _send_wecom(pconfig.extra, chat_id, chunk) + elif platform == Platform.BLUEBUBBLES: + result = await _send_bluebubbles(pconfig.extra, chat_id, chunk) else: result = {"error": f"Direct sending not yet implemented for {platform.value}"} @@ -870,6 +873,33 @@ async def _send_wecom(extra, chat_id, message): return _error(f"WeCom send failed: {e}") +async def _send_bluebubbles(extra, chat_id, message): + """Send via BlueBubbles iMessage server using the adapter's REST API.""" + try: + from gateway.platforms.bluebubbles import BlueBubblesAdapter, check_bluebubbles_requirements + if not check_bluebubbles_requirements(): + return {"error": "BlueBubbles requirements not met (need aiohttp + httpx)."} + except ImportError: + return {"error": "BlueBubbles adapter not available."} + + try: + from gateway.config import PlatformConfig + pconfig = PlatformConfig(extra=extra) + adapter = BlueBubblesAdapter(pconfig) + connected = await adapter.connect() + if not connected: + return _error("BlueBubbles: failed to connect to server") + try: + result = await adapter.send(chat_id, message) + if not result.success: + return _error(f"BlueBubbles send failed: {result.error}") + return {"success": True, "platform": "bluebubbles", "chat_id": chat_id, "message_id": result.message_id} + finally: + await adapter.disconnect() + except Exception as e: + return _error(f"BlueBubbles send failed: {e}") + + async def _send_feishu(pconfig, chat_id, message, media_files=None, thread_id=None): """Send via Feishu/Lark using the adapter's send pipeline.""" try: diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index 6206c4aa69..0dc0fd5872 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -326,8 +326,123 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str: if "HERMES_SPINNER_PAUSE" in os.environ: del os.environ["HERMES_SPINNER_PAUSE"] +def _safe_command_preview(command: Any, limit: int = 200) -> str: + """Return a log-safe preview for possibly-invalid command values.""" + if command is None: + return "" + if isinstance(command, str): + return command[:limit] + try: + return repr(command)[:limit] + except Exception: + return f"<{type(command).__name__}>" -def _transform_sudo_command(command: str) -> tuple[str, str | None]: +def _looks_like_env_assignment(token: str) -> bool: + """Return True when *token* is a leading shell environment assignment.""" + if "=" not in token or token.startswith("="): + return False + name, _value = token.split("=", 1) + return bool(re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", name)) + + +def _read_shell_token(command: str, start: int) -> tuple[str, int]: + """Read one shell token, preserving quotes/escapes, starting at *start*.""" + i = start + n = len(command) + + while i < n: + ch = command[i] + if ch.isspace() or ch in ";|&()": + break + if ch == "'": + i += 1 + while i < n and command[i] != "'": + i += 1 + if i < n: + i += 1 + continue + if ch == '"': + i += 1 + while i < n: + inner = command[i] + if inner == "\\" and i + 1 < n: + i += 2 + continue + if inner == '"': + i += 1 + break + i += 1 + continue + if ch == "\\" and i + 1 < n: + i += 2 + continue + i += 1 + + return command[start:i], i + + +def _rewrite_real_sudo_invocations(command: str) -> tuple[str, bool]: + """Rewrite only real unquoted sudo command words, not plain text mentions.""" + out: list[str] = [] + i = 0 + n = len(command) + command_start = True + found = False + + while i < n: + ch = command[i] + + if ch.isspace(): + out.append(ch) + if ch == "\n": + command_start = True + i += 1 + continue + + if ch == "#" and command_start: + comment_end = command.find("\n", i) + if comment_end == -1: + out.append(command[i:]) + break + out.append(command[i:comment_end]) + i = comment_end + continue + + if command.startswith("&&", i) or command.startswith("||", i) or command.startswith(";;", i): + out.append(command[i:i + 2]) + i += 2 + command_start = True + continue + + if ch in ";|&(": + out.append(ch) + i += 1 + command_start = True + continue + + if ch == ")": + out.append(ch) + i += 1 + command_start = False + continue + + token, next_i = _read_shell_token(command, i) + if command_start and token == "sudo": + out.append("sudo -S -p ''") + found = True + else: + out.append(token) + + if command_start and _looks_like_env_assignment(token): + command_start = True + else: + command_start = False + i = next_i + + return "".join(out), found + + +def _transform_sudo_command(command: str | None) -> tuple[str | None, str | None]: """ Transform sudo commands to use -S flag if SUDO_PASSWORD is available. @@ -362,37 +477,26 @@ def _transform_sudo_command(command: str) -> tuple[str, str | None]: Command runs as-is (fails gracefully with "sudo: a password is required"). """ global _cached_sudo_password - import re - # Check if command even contains sudo - if not re.search(r'\bsudo\b', command): - return command, None # No sudo in command, nothing to do + if command is None: + return None, None + transformed, has_real_sudo = _rewrite_real_sudo_invocations(command) + if not has_real_sudo: + return command, None - # Try to get password from: env var -> session cache -> interactive prompt - sudo_password = os.getenv("SUDO_PASSWORD", "") or _cached_sudo_password + has_configured_password = "SUDO_PASSWORD" in os.environ + sudo_password = os.environ.get("SUDO_PASSWORD", "") if has_configured_password else _cached_sudo_password - if not sudo_password: - # No password configured - check if we're in interactive mode - if os.getenv("HERMES_INTERACTIVE"): - # Prompt user for password - sudo_password = _prompt_for_sudo_password(timeout_seconds=45) - if sudo_password: - _cached_sudo_password = sudo_password # Cache for session + if not has_configured_password and not sudo_password and os.getenv("HERMES_INTERACTIVE"): + sudo_password = _prompt_for_sudo_password(timeout_seconds=45) + if sudo_password: + _cached_sudo_password = sudo_password - if not sudo_password: - return command, None # No password, let it fail gracefully + if has_configured_password or sudo_password: + # Trailing newline is required: sudo -S reads one line for the password. + return transformed, sudo_password + "\n" - def replace_sudo(match): - # Replace bare 'sudo' with 'sudo -S -p ""'. - # The password is returned as sudo_stdin and must be written to the - # process's stdin pipe by the caller — it never appears in any - # command-line argument or shell string. - return "sudo -S -p ''" - - # Match 'sudo' at word boundaries (not 'visudo' or 'sudoers') - transformed = re.sub(r'\bsudo\b', replace_sudo, command) - # Trailing newline is required: sudo -S reads one line for the password. - return transformed, sudo_password + "\n" + return command, None # Environment classes now live in tools/environments/ @@ -611,9 +715,7 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, docker_env = cc.get("docker_env", {}) if env_type == "local": - lc = local_config or {} - return _LocalEnvironment(cwd=cwd, timeout=timeout, - persistent=lc.get("persistent", False)) + return _LocalEnvironment(cwd=cwd, timeout=timeout) elif env_type == "docker": return _DockerEnvironment( @@ -705,7 +807,6 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, key_path=ssh_config.get("key", ""), cwd=cwd, timeout=timeout, - persistent=ssh_config.get("persistent", False), ) else: @@ -817,6 +918,23 @@ def get_active_env(task_id: str): return _active_environments.get(task_id) +def is_persistent_env(task_id: str) -> bool: + """Return True if the active environment for task_id is configured for + cross-turn persistence (``persistent_filesystem=True``). + + Used by the agent loop to skip per-turn teardown for backends whose whole + point is to survive between turns (docker with ``container_persistent``, + daytona, modal, etc.). Non-persistent backends (e.g. Morph) still get torn + down at end-of-turn to prevent leakage. The idle reaper + (``_cleanup_inactive_envs``) handles persistent envs once they exceed + ``terminal.lifetime_seconds``. + """ + env = get_active_env(task_id) + if env is None: + return False + return bool(getattr(env, "_persistent", False)) + + def get_active_environments_info() -> Dict[str, Any]: """Get information about currently active environments.""" info = { @@ -1036,6 +1154,18 @@ def terminal_tool( # Note: force parameter is internal only, not exposed to model API """ try: + if not isinstance(command, str): + logger.warning( + "Rejected invalid terminal command value: %s", + type(command).__name__, + ) + return json.dumps({ + "output": "", + "exit_code": -1, + "error": f"Invalid command: expected string, got {type(command).__name__}", + "status": "error", + }, ensure_ascii=False) + # Get configuration config = _get_env_config() env_type = config["env_type"] @@ -1193,7 +1323,7 @@ def terminal_tool( workdir_error = _validate_workdir(workdir) if workdir_error: logger.warning("Blocked dangerous workdir: %s (command: %s)", - workdir[:200], command[:200]) + workdir[:200], _safe_command_preview(command)) return json.dumps({ "output": "", "exit_code": -1, @@ -1333,12 +1463,12 @@ def terminal_tool( retry_count += 1 wait_time = 2 ** retry_count logger.warning("Execution error, retrying in %ds (attempt %d/%d) - Command: %s - Error: %s: %s - Task: %s, Backend: %s", - wait_time, retry_count, max_retries, command[:200], type(e).__name__, e, effective_task_id, env_type) + wait_time, retry_count, max_retries, _safe_command_preview(command), type(e).__name__, e, effective_task_id, env_type) time.sleep(wait_time) continue logger.error("Execution failed after %d retries - Command: %s - Error: %s: %s - Task: %s, Backend: %s", - max_retries, command[:200], type(e).__name__, e, effective_task_id, env_type) + max_retries, _safe_command_preview(command), type(e).__name__, e, effective_task_id, env_type) return json.dumps({ "output": "", "exit_code": -1, diff --git a/toolsets.py b/toolsets.py index 2a359b60a7..a786ee7c66 100644 --- a/toolsets.py +++ b/toolsets.py @@ -311,6 +311,12 @@ TOOLSETS = { "includes": [] }, + "hermes-bluebubbles": { + "description": "BlueBubbles iMessage bot toolset - Apple iMessage via local BlueBubbles server", + "tools": _HERMES_CORE_TOOLS, + "includes": [] + }, + "hermes-homeassistant": { "description": "Home Assistant bot toolset - smart home event monitoring and control", "tools": _HERMES_CORE_TOOLS, @@ -368,7 +374,7 @@ TOOLSETS = { "hermes-gateway": { "description": "Gateway toolset - union of all messaging platform tools", "tools": [], - "includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-signal", "hermes-homeassistant", "hermes-email", "hermes-sms", "hermes-mattermost", "hermes-matrix", "hermes-dingtalk", "hermes-feishu", "hermes-wecom", "hermes-webhook"] + "includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-signal", "hermes-bluebubbles", "hermes-homeassistant", "hermes-email", "hermes-sms", "hermes-mattermost", "hermes-matrix", "hermes-dingtalk", "hermes-feishu", "hermes-wecom", "hermes-webhook"] } } diff --git a/website/docs/developer-guide/architecture.md b/website/docs/developer-guide/architecture.md index c08161b32f..38fbfb138c 100644 --- a/website/docs/developer-guide/architecture.md +++ b/website/docs/developer-guide/architecture.md @@ -116,9 +116,9 @@ hermes-agent/ │ ├── mirror.py # Cross-session message mirroring │ ├── status.py # Token locks, profile-scoped process tracking │ ├── builtin_hooks/ # Always-registered hooks -│ └── platforms/ # 14 adapters: telegram, discord, slack, whatsapp, +│ └── platforms/ # 15 adapters: telegram, discord, slack, whatsapp, │ # signal, matrix, mattermost, email, sms, -│ # dingtalk, feishu, wecom, homeassistant, webhook +│ # dingtalk, feishu, wecom, bluebubbles, homeassistant, webhook │ ├── acp_adapter/ # ACP server (VS Code / Zed / JetBrains) ├── cron/ # Scheduler (jobs.py, scheduler.py) diff --git a/website/docs/developer-guide/cron-internals.md b/website/docs/developer-guide/cron-internals.md index cc8435dbee..2f14d4e1a5 100644 --- a/website/docs/developer-guide/cron-internals.md +++ b/website/docs/developer-guide/cron-internals.md @@ -153,6 +153,7 @@ Cron job results can be delivered to any supported platform: | DingTalk | `dingtalk` | Deliver to DingTalk | | Feishu | `feishu` | Deliver to Feishu | | WeCom | `wecom` | Deliver to WeCom | +| BlueBubbles | `bluebubbles` | Deliver to iMessage via BlueBubbles | For Telegram topics, use the format `telegram::` (e.g., `telegram:-1001234567890:17585`). diff --git a/website/docs/developer-guide/gateway-internals.md b/website/docs/developer-guide/gateway-internals.md index 1371bdd340..cf25cecd9a 100644 --- a/website/docs/developer-guide/gateway-internals.md +++ b/website/docs/developer-guide/gateway-internals.md @@ -160,6 +160,7 @@ gateway/platforms/ ├── dingtalk.py # DingTalk WebSocket ├── feishu.py # Feishu/Lark WebSocket or webhook ├── wecom.py # WeCom (WeChat Work) callback +├── bluebubbles.py # Apple iMessage via BlueBubbles macOS server ├── webhook.py # Inbound/outbound webhook adapter ├── api_server.py # REST API server adapter └── homeassistant.py # Home Assistant conversation integration diff --git a/website/docs/getting-started/nix-setup.md b/website/docs/getting-started/nix-setup.md index 8bd1924053..4db4939868 100644 --- a/website/docs/getting-started/nix-setup.md +++ b/website/docs/getting-started/nix-setup.md @@ -74,7 +74,7 @@ This module requires NixOS. For non-NixOS systems (macOS, other Linux distros), # /etc/nixos/flake.nix (or your system flake) { inputs = { - nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.11"; + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; hermes-agent.url = "github:NousResearch/hermes-agent"; }; diff --git a/website/docs/index.md b/website/docs/index.md index f4b5378f4c..0f180673ac 100644 --- a/website/docs/index.md +++ b/website/docs/index.md @@ -46,7 +46,7 @@ It's not a coding copilot tethered to an IDE or a chatbot wrapper around a singl - **A closed learning loop** — Agent-curated memory with periodic nudges, autonomous skill creation, skill self-improvement during use, FTS5 cross-session recall with LLM summarization, and [Honcho](https://github.com/plastic-labs/honcho) dialectic user modeling - **Runs anywhere, not just your laptop** — 6 terminal backends: local, Docker, SSH, Daytona, Singularity, Modal. Daytona and Modal offer serverless persistence — your environment hibernates when idle, costing nearly nothing -- **Lives where you do** — CLI, Telegram, Discord, Slack, WhatsApp, Signal, Matrix, Mattermost, Email, SMS, DingTalk, Feishu, WeCom, Home Assistant — 14+ platforms from one gateway +- **Lives where you do** — CLI, Telegram, Discord, Slack, WhatsApp, Signal, Matrix, Mattermost, Email, SMS, DingTalk, Feishu, WeCom, BlueBubbles, Home Assistant — 15+ platforms from one gateway - **Built by model trainers** — Created by [Nous Research](https://nousresearch.com), the lab behind Hermes, Nomos, and Psyche. Works with [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai), OpenAI, or any endpoint - **Scheduled automations** — Built-in cron with delivery to any platform - **Delegates & parallelizes** — Spawn isolated subagents for parallel workstreams. Programmatic Tool Calling via `execute_code` collapses multi-step pipelines into single inference calls diff --git a/website/docs/integrations/index.md b/website/docs/integrations/index.md index ce103f1cc8..e6fe54f776 100644 --- a/website/docs/integrations/index.md +++ b/website/docs/integrations/index.md @@ -80,9 +80,9 @@ Speech-to-text supports three providers: local Whisper (free, runs on-device), G ## Messaging Platforms -Hermes runs as a gateway bot on 14+ messaging platforms, all configured through the same `gateway` subsystem: +Hermes runs as a gateway bot on 15+ messaging platforms, all configured through the same `gateway` subsystem: -- **[Telegram](/docs/user-guide/messaging/telegram)**, **[Discord](/docs/user-guide/messaging/discord)**, **[Slack](/docs/user-guide/messaging/slack)**, **[WhatsApp](/docs/user-guide/messaging/whatsapp)**, **[Signal](/docs/user-guide/messaging/signal)**, **[Matrix](/docs/user-guide/messaging/matrix)**, **[Mattermost](/docs/user-guide/messaging/mattermost)**, **[Email](/docs/user-guide/messaging/email)**, **[SMS](/docs/user-guide/messaging/sms)**, **[DingTalk](/docs/user-guide/messaging/dingtalk)**, **[Feishu/Lark](/docs/user-guide/messaging/feishu)**, **[WeCom](/docs/user-guide/messaging/wecom)**, **[Home Assistant](/docs/user-guide/messaging/homeassistant)**, **[Webhooks](/docs/user-guide/messaging/webhooks)** +- **[Telegram](/docs/user-guide/messaging/telegram)**, **[Discord](/docs/user-guide/messaging/discord)**, **[Slack](/docs/user-guide/messaging/slack)**, **[WhatsApp](/docs/user-guide/messaging/whatsapp)**, **[Signal](/docs/user-guide/messaging/signal)**, **[Matrix](/docs/user-guide/messaging/matrix)**, **[Mattermost](/docs/user-guide/messaging/mattermost)**, **[Email](/docs/user-guide/messaging/email)**, **[SMS](/docs/user-guide/messaging/sms)**, **[DingTalk](/docs/user-guide/messaging/dingtalk)**, **[Feishu/Lark](/docs/user-guide/messaging/feishu)**, **[WeCom](/docs/user-guide/messaging/wecom)**, **[BlueBubbles](/docs/user-guide/messaging/bluebubbles)**, **[Home Assistant](/docs/user-guide/messaging/homeassistant)**, **[Webhooks](/docs/user-guide/messaging/webhooks)** See the [Messaging Gateway overview](/docs/user-guide/messaging) for the platform comparison table and setup guide. diff --git a/website/docs/integrations/providers.md b/website/docs/integrations/providers.md index 74d4e631ae..fbfa69ade6 100644 --- a/website/docs/integrations/providers.md +++ b/website/docs/integrations/providers.md @@ -230,7 +230,7 @@ model: ``` :::warning Legacy env vars -`OPENAI_BASE_URL` and `LLM_MODEL` in `.env` are **deprecated**. `OPENAI_BASE_URL` is no longer consulted for endpoint resolution — `config.yaml` is the single source of truth. The CLI ignores `LLM_MODEL` entirely (only the gateway reads it as a fallback). Use `hermes model` or edit `config.yaml` directly — both persist correctly across restarts and Docker containers. +`OPENAI_BASE_URL` and `LLM_MODEL` in `.env` are **removed**. Neither is read by any part of Hermes — `config.yaml` is the single source of truth for model and endpoint configuration. If you have stale entries in your `.env`, they are automatically cleared on the next `hermes setup` or config migration. Use `hermes model` or edit `config.yaml` directly. ::: Both approaches persist to `config.yaml`, which is the source of truth for model, provider, and base URL. diff --git a/website/docs/reference/cli-commands.md b/website/docs/reference/cli-commands.md index 55983b1c69..a7362b06ff 100644 --- a/website/docs/reference/cli-commands.md +++ b/website/docs/reference/cli-commands.md @@ -43,6 +43,8 @@ hermes [global-options] [subcommand/options] | `hermes cron` | Inspect and tick the cron scheduler. | | `hermes webhook` | Manage dynamic webhook subscriptions for event-driven activation. | | `hermes doctor` | Diagnose config and dependency issues. | +| `hermes dump` | Copy-pasteable setup summary for support/debugging. | +| `hermes logs` | View, tail, and filter agent/gateway/error log files. | | `hermes config` | Show, edit, migrate, and query configuration files. | | `hermes pairing` | Approve or revoke messaging pairing codes. | | `hermes skills` | Browse, install, publish, audit, and configure skills. | @@ -272,6 +274,149 @@ hermes doctor [--fix] |--------|-------------| | `--fix` | Attempt automatic repairs where possible. | +## `hermes dump` + +```bash +hermes dump [--show-keys] +``` + +Outputs a compact, plain-text summary of your entire Hermes setup. Designed to be copy-pasted into Discord, GitHub issues, or Telegram when asking for support — no ANSI colors, no special formatting, just data. + +| Option | Description | +|--------|-------------| +| `--show-keys` | Show redacted API key prefixes (first and last 4 characters) instead of just `set`/`not set`. | + +### What it includes + +| Section | Details | +|---------|---------| +| **Header** | Hermes version, release date, git commit hash | +| **Environment** | OS, Python version, OpenAI SDK version | +| **Identity** | Active profile name, HERMES_HOME path | +| **Model** | Configured default model and provider | +| **Terminal** | Backend type (local, docker, ssh, etc.) | +| **API keys** | Presence check for all 22 provider/tool API keys | +| **Features** | Enabled toolsets, MCP server count, memory provider | +| **Services** | Gateway status, configured messaging platforms | +| **Workload** | Cron job counts, installed skill count | +| **Config overrides** | Any config values that differ from defaults | + +### Example output + +``` +--- hermes dump --- +version: 0.8.0 (2026.4.8) [af4abd2f] +os: Linux 6.14.0-37-generic x86_64 +python: 3.11.14 +openai_sdk: 2.24.0 +profile: default +hermes_home: ~/.hermes +model: anthropic/claude-opus-4.6 +provider: openrouter +terminal: local + +api_keys: + openrouter set + openai not set + anthropic set + nous not set + firecrawl set + ... + +features: + toolsets: all + mcp_servers: 0 + memory_provider: built-in + gateway: running (systemd) + platforms: telegram, discord + cron_jobs: 3 active / 5 total + skills: 42 + +config_overrides: + agent.max_turns: 250 + compression.threshold: 0.85 + display.streaming: True +--- end dump --- +``` + +### When to use + +- Reporting a bug on GitHub — paste the dump into your issue +- Asking for help in Discord — share it in a code block +- Comparing your setup to someone else's +- Quick sanity check when something isn't working + +:::tip +`hermes dump` is specifically designed for sharing. For interactive diagnostics, use `hermes doctor`. For a visual overview, use `hermes status`. +::: + +## `hermes logs` + +```bash +hermes logs [log_name] [options] +``` + +View, tail, and filter Hermes log files. All logs are stored in `~/.hermes/logs/` (or `/logs/` for non-default profiles). + +### Log files + +| Name | File | What it captures | +|------|------|-----------------| +| `agent` (default) | `agent.log` | All agent activity — API calls, tool dispatch, session lifecycle (INFO and above) | +| `errors` | `errors.log` | Warnings and errors only — a filtered subset of agent.log | +| `gateway` | `gateway.log` | Messaging gateway activity — platform connections, message dispatch, webhook events | + +### Options + +| Option | Description | +|--------|-------------| +| `log_name` | Which log to view: `agent` (default), `errors`, `gateway`, or `list` to show available files with sizes. | +| `-n`, `--lines ` | Number of lines to show (default: 50). | +| `-f`, `--follow` | Follow the log in real time, like `tail -f`. Press Ctrl+C to stop. | +| `--level ` | Minimum log level to show: `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`. | +| `--session ` | Filter lines containing a session ID substring. | +| `--since