mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-09 08:21:50 +00:00
Merge branch 'main' of github.com:NousResearch/hermes-agent into feat/ink-refactor
This commit is contained in:
commit
b2ea9b4176
112 changed files with 9087 additions and 2195 deletions
20
.github/workflows/docker-publish.yml
vendored
20
.github/workflows/docker-publish.yml
vendored
|
|
@ -8,6 +8,9 @@ on:
|
|||
release:
|
||||
types: [published]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: docker-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
|
@ -17,22 +20,29 @@ jobs:
|
|||
# Only run on the upstream repository, not on forks
|
||||
if: github.repository == 'NousResearch/hermes-agent'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
timeout-minutes: 60
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Build image
|
||||
# Build amd64 only so we can `load` the image for smoke testing.
|
||||
# `load: true` cannot export a multi-arch manifest to the local daemon.
|
||||
# The multi-arch build follows on push to main / release.
|
||||
- name: Build image (amd64, smoke test)
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: Dockerfile
|
||||
load: true
|
||||
platforms: linux/amd64
|
||||
tags: nousresearch/hermes-agent:test
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
|
@ -51,26 +61,28 @@ jobs:
|
|||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Push image (main branch)
|
||||
- name: Push multi-arch image (main branch)
|
||||
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: Dockerfile
|
||||
push: true
|
||||
platforms: linux/amd64,linux/arm64
|
||||
tags: |
|
||||
nousresearch/hermes-agent:latest
|
||||
nousresearch/hermes-agent:${{ github.sha }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
- name: Push image (release)
|
||||
- name: Push multi-arch image (release)
|
||||
if: github.event_name == 'release'
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: Dockerfile
|
||||
push: true
|
||||
platforms: linux/amd64,linux/arm64
|
||||
tags: |
|
||||
nousresearch/hermes-agent:latest
|
||||
nousresearch/hermes-agent:${{ github.event.release.tag_name }}
|
||||
|
|
|
|||
4
.github/workflows/docs-site-checks.yml
vendored
4
.github/workflows/docs-site-checks.yml
vendored
|
|
@ -27,8 +27,8 @@ jobs:
|
|||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: python -m pip install ascii-guard pyyaml
|
||||
- name: Install ascii-guard
|
||||
run: python -m pip install ascii-guard==2.3.0 pyyaml==6.0.3
|
||||
|
||||
- name: Extract skill metadata for dashboard
|
||||
run: python3 website/scripts/extract-skills.py
|
||||
|
|
|
|||
4
.github/workflows/nix.yml
vendored
4
.github/workflows/nix.yml
vendored
|
|
@ -27,8 +27,8 @@ jobs:
|
|||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: DeterminateSystems/nix-installer-action@main
|
||||
- uses: DeterminateSystems/magic-nix-cache-action@main
|
||||
- uses: DeterminateSystems/nix-installer-action@ef8a148080ab6020fd15196c2084a2eea5ff2d25 # v22
|
||||
- uses: DeterminateSystems/magic-nix-cache-action@565684385bcd71bad329742eefe8d12f2e765b39 # v13
|
||||
- name: Check flake
|
||||
if: runner.os == 'Linux'
|
||||
run: nix flake check --print-build-logs
|
||||
|
|
|
|||
|
|
@ -629,11 +629,19 @@ def _nous_base_url() -> str:
|
|||
|
||||
|
||||
def _read_codex_access_token() -> Optional[str]:
|
||||
"""Read a valid, non-expired Codex OAuth access token from Hermes auth store."""
|
||||
"""Read a valid, non-expired Codex OAuth access token from Hermes auth store.
|
||||
|
||||
If a credential pool exists but currently has no selectable runtime entry
|
||||
(for example all pool slots are marked exhausted), fall back to the
|
||||
profile's auth.json token instead of hard-failing. This keeps explicit
|
||||
fallback-to-Codex working when the pool state is stale but the stored OAuth
|
||||
token is still valid.
|
||||
"""
|
||||
pool_present, entry = _select_pool_entry("openai-codex")
|
||||
if pool_present:
|
||||
token = _pool_runtime_api_key(entry)
|
||||
return token or None
|
||||
if token:
|
||||
return token
|
||||
|
||||
try:
|
||||
from hermes_cli.auth import _read_codex_tokens
|
||||
|
|
@ -894,9 +902,13 @@ def _try_codex() -> Tuple[Optional[Any], Optional[str]]:
|
|||
pool_present, entry = _select_pool_entry("openai-codex")
|
||||
if pool_present:
|
||||
codex_token = _pool_runtime_api_key(entry)
|
||||
if not codex_token:
|
||||
return None, None
|
||||
base_url = _pool_runtime_base_url(entry, _CODEX_AUX_BASE_URL) or _CODEX_AUX_BASE_URL
|
||||
if codex_token:
|
||||
base_url = _pool_runtime_base_url(entry, _CODEX_AUX_BASE_URL) or _CODEX_AUX_BASE_URL
|
||||
else:
|
||||
codex_token = _read_codex_access_token()
|
||||
if not codex_token:
|
||||
return None, None
|
||||
base_url = _CODEX_AUX_BASE_URL
|
||||
else:
|
||||
codex_token = _read_codex_access_token()
|
||||
if not codex_token:
|
||||
|
|
|
|||
|
|
@ -154,12 +154,15 @@ class ContextCompressor:
|
|||
|
||||
def _prune_old_tool_results(
|
||||
self, messages: List[Dict[str, Any]], protect_tail_count: int,
|
||||
protect_tail_tokens: int | None = None,
|
||||
) -> tuple[List[Dict[str, Any]], int]:
|
||||
"""Replace old tool result contents with a short placeholder.
|
||||
|
||||
Walks backward from the end, protecting the most recent
|
||||
``protect_tail_count`` messages. Older tool results get their
|
||||
content replaced with a placeholder string.
|
||||
Walks backward from the end, protecting the most recent messages that
|
||||
fall within ``protect_tail_tokens`` (when provided) OR the last
|
||||
``protect_tail_count`` messages (backward-compatible default).
|
||||
When both are given, the token budget takes priority and the message
|
||||
count acts as a hard minimum floor.
|
||||
|
||||
Returns (pruned_messages, pruned_count).
|
||||
"""
|
||||
|
|
@ -168,7 +171,29 @@ class ContextCompressor:
|
|||
|
||||
result = [m.copy() for m in messages]
|
||||
pruned = 0
|
||||
prune_boundary = len(result) - protect_tail_count
|
||||
|
||||
# Determine the prune boundary
|
||||
if protect_tail_tokens is not None and protect_tail_tokens > 0:
|
||||
# Token-budget approach: walk backward accumulating tokens
|
||||
accumulated = 0
|
||||
boundary = len(result)
|
||||
min_protect = min(protect_tail_count, len(result) - 1)
|
||||
for i in range(len(result) - 1, -1, -1):
|
||||
msg = result[i]
|
||||
content_len = len(msg.get("content") or "")
|
||||
msg_tokens = content_len // _CHARS_PER_TOKEN + 10
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
if isinstance(tc, dict):
|
||||
args = tc.get("function", {}).get("arguments", "")
|
||||
msg_tokens += len(args) // _CHARS_PER_TOKEN
|
||||
if accumulated + msg_tokens > protect_tail_tokens and (len(result) - i) >= min_protect:
|
||||
boundary = i
|
||||
break
|
||||
accumulated += msg_tokens
|
||||
boundary = i
|
||||
prune_boundary = max(boundary, len(result) - min_protect)
|
||||
else:
|
||||
prune_boundary = len(result) - protect_tail_count
|
||||
|
||||
for i in range(prune_boundary):
|
||||
msg = result[i]
|
||||
|
|
@ -199,30 +224,39 @@ class ContextCompressor:
|
|||
budget = int(content_tokens * _SUMMARY_RATIO)
|
||||
return max(_MIN_SUMMARY_TOKENS, min(budget, self.max_summary_tokens))
|
||||
|
||||
# Truncation limits for the summarizer input. These bound how much of
|
||||
# each message the summary model sees — the budget is the *summary*
|
||||
# model's context window, not the main model's.
|
||||
_CONTENT_MAX = 6000 # total chars per message body
|
||||
_CONTENT_HEAD = 4000 # chars kept from the start
|
||||
_CONTENT_TAIL = 1500 # chars kept from the end
|
||||
_TOOL_ARGS_MAX = 1500 # tool call argument chars
|
||||
_TOOL_ARGS_HEAD = 1200 # kept from the start of tool args
|
||||
|
||||
def _serialize_for_summary(self, turns: List[Dict[str, Any]]) -> str:
|
||||
"""Serialize conversation turns into labeled text for the summarizer.
|
||||
|
||||
Includes tool call arguments and result content (up to 3000 chars
|
||||
per message) so the summarizer can preserve specific details like
|
||||
file paths, commands, and outputs.
|
||||
Includes tool call arguments and result content (up to
|
||||
``_CONTENT_MAX`` chars per message) so the summarizer can preserve
|
||||
specific details like file paths, commands, and outputs.
|
||||
"""
|
||||
parts = []
|
||||
for msg in turns:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content") or ""
|
||||
|
||||
# Tool results: keep more content than before (3000 chars)
|
||||
# Tool results: keep enough content for the summarizer
|
||||
if role == "tool":
|
||||
tool_id = msg.get("tool_call_id", "")
|
||||
if len(content) > 3000:
|
||||
content = content[:2000] + "\n...[truncated]...\n" + content[-800:]
|
||||
if len(content) > self._CONTENT_MAX:
|
||||
content = content[:self._CONTENT_HEAD] + "\n...[truncated]...\n" + content[-self._CONTENT_TAIL:]
|
||||
parts.append(f"[TOOL RESULT {tool_id}]: {content}")
|
||||
continue
|
||||
|
||||
# Assistant messages: include tool call names AND arguments
|
||||
if role == "assistant":
|
||||
if len(content) > 3000:
|
||||
content = content[:2000] + "\n...[truncated]...\n" + content[-800:]
|
||||
if len(content) > self._CONTENT_MAX:
|
||||
content = content[:self._CONTENT_HEAD] + "\n...[truncated]...\n" + content[-self._CONTENT_TAIL:]
|
||||
tool_calls = msg.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
tc_parts = []
|
||||
|
|
@ -232,8 +266,8 @@ class ContextCompressor:
|
|||
name = fn.get("name", "?")
|
||||
args = fn.get("arguments", "")
|
||||
# Truncate long arguments but keep enough for context
|
||||
if len(args) > 500:
|
||||
args = args[:400] + "..."
|
||||
if len(args) > self._TOOL_ARGS_MAX:
|
||||
args = args[:self._TOOL_ARGS_HEAD] + "..."
|
||||
tc_parts.append(f" {name}({args})")
|
||||
else:
|
||||
fn = getattr(tc, "function", None)
|
||||
|
|
@ -244,8 +278,8 @@ class ContextCompressor:
|
|||
continue
|
||||
|
||||
# User and other roles
|
||||
if len(content) > 3000:
|
||||
content = content[:2000] + "\n...[truncated]...\n" + content[-800:]
|
||||
if len(content) > self._CONTENT_MAX:
|
||||
content = content[:self._CONTENT_HEAD] + "\n...[truncated]...\n" + content[-self._CONTENT_TAIL:]
|
||||
parts.append(f"[{role.upper()}]: {content}")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
|
@ -310,6 +344,9 @@ Update the summary using this exact structure. PRESERVE all existing information
|
|||
## Critical Context
|
||||
[Any specific values, error messages, configuration details, or data that would be lost without explicit preservation]
|
||||
|
||||
## Tools & Patterns
|
||||
[Which tools were used, how they were used effectively, and any tool-specific discoveries. Accumulate across compactions.]
|
||||
|
||||
Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions.
|
||||
|
||||
Write only the summary body. Do not include any preamble or prefix."""
|
||||
|
|
@ -348,6 +385,9 @@ Use this exact structure:
|
|||
## Critical Context
|
||||
[Any specific values, error messages, configuration details, or data that would be lost without explicit preservation]
|
||||
|
||||
## Tools & Patterns
|
||||
[Which tools were used, how they were used effectively, and any tool-specific discoveries (e.g., preferred flags, working invocations, successful command patterns)]
|
||||
|
||||
Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions. The goal is to prevent the next assistant from repeating work or losing important details.
|
||||
|
||||
Write only the summary body. Do not include any preamble or prefix."""
|
||||
|
|
@ -518,13 +558,20 @@ Write only the summary body. Do not include any preamble or prefix."""
|
|||
derived from ``summary_target_ratio * context_length``, so it
|
||||
scales automatically with the model's context window.
|
||||
|
||||
Never cuts inside a tool_call/result group. Falls back to the old
|
||||
``protect_last_n`` if the budget would protect fewer messages.
|
||||
Token budget is the primary criterion. A hard minimum of 3 messages
|
||||
is always protected, but the budget may be exceeded by up to
|
||||
1.5x to avoid cutting inside an oversized message (tool output, file
|
||||
read, etc.). If even the minimum 3 messages exceed 1.5x the budget
|
||||
the cut is placed right after the head so compression still runs.
|
||||
|
||||
Never cuts inside a tool_call/result group.
|
||||
"""
|
||||
if token_budget is None:
|
||||
token_budget = self.tail_token_budget
|
||||
n = len(messages)
|
||||
min_tail = self.protect_last_n
|
||||
# Hard minimum: always keep at least 3 messages in the tail
|
||||
min_tail = min(3, n - head_end - 1) if n - head_end > 1 else 0
|
||||
soft_ceiling = int(token_budget * 1.5)
|
||||
accumulated = 0
|
||||
cut_idx = n # start from beyond the end
|
||||
|
||||
|
|
@ -537,21 +584,21 @@ Write only the summary body. Do not include any preamble or prefix."""
|
|||
if isinstance(tc, dict):
|
||||
args = tc.get("function", {}).get("arguments", "")
|
||||
msg_tokens += len(args) // _CHARS_PER_TOKEN
|
||||
if accumulated + msg_tokens > token_budget and (n - i) >= min_tail:
|
||||
# Stop once we exceed the soft ceiling (unless we haven't hit min_tail yet)
|
||||
if accumulated + msg_tokens > soft_ceiling and (n - i) >= min_tail:
|
||||
break
|
||||
accumulated += msg_tokens
|
||||
cut_idx = i
|
||||
|
||||
# Ensure we protect at least protect_last_n messages
|
||||
# Ensure we protect at least min_tail messages
|
||||
fallback_cut = n - min_tail
|
||||
if cut_idx > fallback_cut:
|
||||
cut_idx = fallback_cut
|
||||
|
||||
# If the token budget would protect everything (small conversations),
|
||||
# fall back to the fixed protect_last_n approach so compression can
|
||||
# still remove middle turns.
|
||||
# force a cut after the head so compression can still remove middle turns.
|
||||
if cut_idx <= head_end:
|
||||
cut_idx = fallback_cut
|
||||
cut_idx = max(fallback_cut, head_end + 1)
|
||||
|
||||
# Align to avoid splitting tool groups
|
||||
cut_idx = self._align_boundary_backward(messages, cut_idx)
|
||||
|
|
@ -576,12 +623,13 @@ Write only the summary body. Do not include any preamble or prefix."""
|
|||
up so the API never receives mismatched IDs.
|
||||
"""
|
||||
n_messages = len(messages)
|
||||
if n_messages <= self.protect_first_n + self.protect_last_n + 1:
|
||||
# Only need head + 3 tail messages minimum (token budget decides the real tail size)
|
||||
_min_for_compress = self.protect_first_n + 3 + 1
|
||||
if n_messages <= _min_for_compress:
|
||||
if not self.quiet_mode:
|
||||
logger.warning(
|
||||
"Cannot compress: only %d messages (need > %d)",
|
||||
n_messages,
|
||||
self.protect_first_n + self.protect_last_n + 1,
|
||||
n_messages, _min_for_compress,
|
||||
)
|
||||
return messages
|
||||
|
||||
|
|
@ -589,7 +637,8 @@ Write only the summary body. Do not include any preamble or prefix."""
|
|||
|
||||
# Phase 1: Prune old tool results (cheap, no LLM call)
|
||||
messages, pruned_count = self._prune_old_tool_results(
|
||||
messages, protect_tail_count=self.protect_last_n * 3,
|
||||
messages, protect_tail_count=self.protect_last_n,
|
||||
protect_tail_tokens=self.tail_token_budget,
|
||||
)
|
||||
if pruned_count and not self.quiet_mode:
|
||||
logger.info("Pre-compression: pruned %d old tool result(s)", pruned_count)
|
||||
|
|
|
|||
|
|
@ -64,10 +64,10 @@ SUPPORTED_POOL_STRATEGIES = {
|
|||
}
|
||||
|
||||
# Cooldown before retrying an exhausted credential.
|
||||
# 429 (rate-limited) cools down faster since quotas reset frequently.
|
||||
# 402 (billing/quota) and other codes use a longer default.
|
||||
# 429 (rate-limited) and 402 (billing/quota) both cool down after 1 hour.
|
||||
# Provider-supplied reset_at timestamps override these defaults.
|
||||
EXHAUSTED_TTL_429_SECONDS = 60 * 60 # 1 hour
|
||||
EXHAUSTED_TTL_DEFAULT_SECONDS = 24 * 60 * 60 # 24 hours
|
||||
EXHAUSTED_TTL_DEFAULT_SECONDS = 60 * 60 # 1 hour
|
||||
|
||||
# Pool key prefix for custom OpenAI-compatible endpoints.
|
||||
# Custom endpoints all share provider='custom' but are keyed by their
|
||||
|
|
|
|||
789
agent/error_classifier.py
Normal file
789
agent/error_classifier.py
Normal file
|
|
@ -0,0 +1,789 @@
|
|||
"""API error classification for smart failover and recovery.
|
||||
|
||||
Provides a structured taxonomy of API errors and a priority-ordered
|
||||
classification pipeline that determines the correct recovery action
|
||||
(retry, rotate credential, fallback to another provider, compress
|
||||
context, or abort).
|
||||
|
||||
Replaces scattered inline string-matching with a centralized classifier
|
||||
that the main retry loop in run_agent.py consults for every API failure.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Error taxonomy ──────────────────────────────────────────────────────
|
||||
|
||||
@enum.unique
class FailoverReason(enum.Enum):
    """Why an API call failed — determines recovery strategy.

    Every member's value is its own name as a string.  ``@enum.unique``
    makes an accidentally duplicated value fail at import time instead of
    silently turning the second member into an alias of the first.
    """

    # Authentication / authorization
    auth = "auth"                              # Transient auth (401/403) — refresh/rotate
    auth_permanent = "auth_permanent"          # Auth failed after refresh — abort

    # Billing / quota
    billing = "billing"                        # 402 or confirmed credit exhaustion — rotate immediately
    rate_limit = "rate_limit"                  # 429 or quota-based throttling — backoff then rotate

    # Server-side
    overloaded = "overloaded"                  # 503/529 — provider overloaded, backoff
    server_error = "server_error"              # 500/502 — internal server error, retry

    # Transport
    timeout = "timeout"                        # Connection/read timeout — rebuild client + retry

    # Context / payload
    context_overflow = "context_overflow"      # Context too large — compress, not failover
    payload_too_large = "payload_too_large"    # 413 — compress payload

    # Model
    model_not_found = "model_not_found"        # 404 or invalid model — fallback to different model

    # Request format
    format_error = "format_error"              # 400 bad request — abort or strip + retry

    # Provider-specific
    thinking_signature = "thinking_signature"  # Anthropic thinking block sig invalid
    long_context_tier = "long_context_tier"    # Anthropic "extra usage" tier gate

    # Catch-all
    unknown = "unknown"                        # Unclassifiable — retry with backoff
|
||||
|
||||
|
||||
# ── Classification result ───────────────────────────────────────────────
|
||||
|
||||
@dataclass
class ClassifiedError:
    """Structured classification of an API error with recovery hints."""

    # Taxonomy bucket the error was sorted into.
    reason: FailoverReason
    # HTTP status code, when one could be extracted from the error.
    status_code: Optional[int] = None
    # Provider name and model slug as supplied by the caller.
    provider: Optional[str] = None
    model: Optional[str] = None
    # Human-readable error message.
    message: str = ""
    # Free-form extra details; each instance gets its own dict.
    error_context: Dict[str, Any] = field(default_factory=dict)

    # Recovery action hints — the retry loop checks these instead of
    # re-classifying the error itself.
    retryable: bool = True
    should_compress: bool = False
    should_rotate_credential: bool = False
    should_fallback: bool = False

    @property
    def is_auth(self) -> bool:
        """True for both the transient and the permanent auth reasons."""
        return self.reason in (FailoverReason.auth, FailoverReason.auth_permanent)

    @property
    def is_transient(self) -> bool:
        """Error is expected to resolve on retry (with or without backoff)."""
        return self.reason in (
            FailoverReason.rate_limit,
            FailoverReason.overloaded,
            FailoverReason.server_error,
            FailoverReason.timeout,
            FailoverReason.unknown,
        )
|
||||
|
||||
|
||||
# ── Provider-specific patterns ──────────────────────────────────────────
|
||||
|
||||
# Patterns that indicate billing exhaustion (not transient rate limit)
|
||||
_BILLING_PATTERNS = [
|
||||
"insufficient credits",
|
||||
"insufficient_quota",
|
||||
"credit balance",
|
||||
"credits have been exhausted",
|
||||
"top up your credits",
|
||||
"payment required",
|
||||
"billing hard limit",
|
||||
"exceeded your current quota",
|
||||
"account is deactivated",
|
||||
"plan does not include",
|
||||
]
|
||||
|
||||
# Patterns that indicate rate limiting (transient, will resolve)
|
||||
_RATE_LIMIT_PATTERNS = [
|
||||
"rate limit",
|
||||
"rate_limit",
|
||||
"too many requests",
|
||||
"throttled",
|
||||
"requests per minute",
|
||||
"tokens per minute",
|
||||
"requests per day",
|
||||
"try again in",
|
||||
"please retry after",
|
||||
"resource_exhausted",
|
||||
]
|
||||
|
||||
# Usage-limit patterns that need disambiguation (could be billing OR rate_limit)
|
||||
_USAGE_LIMIT_PATTERNS = [
|
||||
"usage limit",
|
||||
"quota",
|
||||
"limit exceeded",
|
||||
"key limit exceeded",
|
||||
]
|
||||
|
||||
# Patterns confirming usage limit is transient (not billing)
|
||||
_USAGE_LIMIT_TRANSIENT_SIGNALS = [
|
||||
"try again",
|
||||
"retry",
|
||||
"resets at",
|
||||
"reset in",
|
||||
"wait",
|
||||
"requests remaining",
|
||||
"periodic",
|
||||
"window",
|
||||
]
|
||||
|
||||
# Payload-too-large patterns detected from message text (no status_code attr).
|
||||
# Proxies and some backends embed the HTTP status in the error message.
|
||||
_PAYLOAD_TOO_LARGE_PATTERNS = [
|
||||
"request entity too large",
|
||||
"payload too large",
|
||||
"error code: 413",
|
||||
]
|
||||
|
||||
# Context overflow patterns
|
||||
_CONTEXT_OVERFLOW_PATTERNS = [
|
||||
"context length",
|
||||
"context size",
|
||||
"maximum context",
|
||||
"token limit",
|
||||
"too many tokens",
|
||||
"reduce the length",
|
||||
"exceeds the limit",
|
||||
"context window",
|
||||
"prompt is too long",
|
||||
"prompt exceeds max length",
|
||||
"max_tokens",
|
||||
"maximum number of tokens",
|
||||
# Chinese error messages (some providers return these)
|
||||
"超过最大长度",
|
||||
"上下文长度",
|
||||
]
|
||||
|
||||
# Model not found patterns
|
||||
_MODEL_NOT_FOUND_PATTERNS = [
|
||||
"is not a valid model",
|
||||
"invalid model",
|
||||
"model not found",
|
||||
"model_not_found",
|
||||
"does not exist",
|
||||
"no such model",
|
||||
"unknown model",
|
||||
"unsupported model",
|
||||
]
|
||||
|
||||
# Auth patterns (non-status-code signals)
|
||||
_AUTH_PATTERNS = [
|
||||
"invalid api key",
|
||||
"invalid_api_key",
|
||||
"authentication",
|
||||
"unauthorized",
|
||||
"forbidden",
|
||||
"invalid token",
|
||||
"token expired",
|
||||
"token revoked",
|
||||
"access denied",
|
||||
]
|
||||
|
||||
# Anthropic thinking block signature patterns
|
||||
_THINKING_SIG_PATTERNS = [
|
||||
"signature", # Combined with "thinking" check
|
||||
]
|
||||
|
||||
# Transport error type names
|
||||
_TRANSPORT_ERROR_TYPES = frozenset({
|
||||
"ReadTimeout", "ConnectTimeout", "PoolTimeout",
|
||||
"ConnectError", "RemoteProtocolError",
|
||||
"ConnectionError", "ConnectionResetError",
|
||||
"ConnectionAbortedError", "BrokenPipeError",
|
||||
"TimeoutError", "ReadError",
|
||||
"ServerDisconnectedError",
|
||||
# OpenAI SDK errors (not subclasses of Python builtins)
|
||||
"APIConnectionError",
|
||||
"APITimeoutError",
|
||||
})
|
||||
|
||||
# Server disconnect patterns (no status code, but transport-level)
|
||||
_SERVER_DISCONNECT_PATTERNS = [
|
||||
"server disconnected",
|
||||
"peer closed connection",
|
||||
"connection reset by peer",
|
||||
"connection was closed",
|
||||
"network connection lost",
|
||||
"unexpected eof",
|
||||
"incomplete chunked read",
|
||||
]
|
||||
|
||||
|
||||
# ── Classification pipeline ─────────────────────────────────────────────
|
||||
|
||||
def classify_api_error(
    error: Exception,
    *,
    provider: str = "",
    model: str = "",
    approx_tokens: int = 0,
    context_length: int = 200000,
    num_messages: int = 0,
) -> ClassifiedError:
    """Classify an API error into a structured recovery recommendation.

    Priority-ordered pipeline:
      1. Special-case provider-specific patterns (thinking sigs, tier gates)
      2. HTTP status code + message-aware refinement
      3. Error code classification (from body)
      4. Message pattern matching (billing vs rate_limit vs context vs auth)
      5. Server disconnect + large session → context overflow
      6. Transport error heuristics
      7. Fallback: unknown (retryable with backoff)

    Args:
        error: The exception from the API call.
        provider: Current provider name (e.g. "openrouter", "anthropic").
        model: Current model slug.
        approx_tokens: Approximate token count of the current context.
        context_length: Maximum context length for the current model.
        num_messages: Number of messages in the current session.  More than
            200 messages marks the session as "large" for the server
            disconnect heuristic (step 5).

    Returns:
        ClassifiedError with reason and recovery action hints.
    """
    status_code = _extract_status_code(error)
    error_type = type(error).__name__
    body = _extract_error_body(error)
    error_code = _extract_error_code(body)

    def _lowered(value: Any) -> str:
        """Return *value* as lowercase text; falsy values become ``""``.

        Some backends send a non-string ``message`` (dict, list, number).
        Coercing it keeps the classifier from raising AttributeError while
        it is itself handling an API failure.
        """
        if not value:
            return ""
        return (value if isinstance(value, str) else str(value)).lower()

    # Build a comprehensive error message string for pattern matching.
    # str(error) alone may not include the body message (e.g. OpenAI SDK's
    # APIStatusError.__str__ returns the first arg, not the body).  Append
    # the body message so patterns like "try again" in 402 disambiguation
    # are detected even when only present in the structured body.
    #
    # Also extract metadata.raw — OpenRouter wraps upstream provider errors
    # inside {"error": {"message": "Provider returned error", "metadata":
    # {"raw": "<actual error JSON>"}}} and the real error message (e.g.
    # "context length exceeded") is only in the inner JSON.
    _raw_msg = str(error).lower()
    _body_msg = ""
    _metadata_msg = ""
    if isinstance(body, dict):
        _err_obj = body.get("error", {})
        if isinstance(_err_obj, dict):
            _body_msg = _lowered(_err_obj.get("message"))
            # Parse metadata.raw for wrapped provider errors
            _metadata = _err_obj.get("metadata", {})
            if isinstance(_metadata, dict):
                _raw_json = _metadata.get("raw") or ""
                if isinstance(_raw_json, str) and _raw_json.strip():
                    import json

                    # Only the parse itself is guarded; a malformed payload
                    # simply contributes nothing to the message text.
                    try:
                        _inner = json.loads(_raw_json)
                    except (json.JSONDecodeError, TypeError):
                        _inner = None
                    if isinstance(_inner, dict):
                        _inner_err = _inner.get("error", {})
                        if isinstance(_inner_err, dict):
                            _metadata_msg = _lowered(_inner_err.get("message"))
        if not _body_msg:
            _body_msg = _lowered(body.get("message"))

    # Combine all message sources for pattern matching, skipping any source
    # whose text is already contained in an earlier one.
    parts = [_raw_msg]
    if _body_msg and _body_msg not in _raw_msg:
        parts.append(_body_msg)
    if _metadata_msg and _metadata_msg not in _raw_msg and _metadata_msg not in _body_msg:
        parts.append(_metadata_msg)
    error_msg = " ".join(parts)
    provider_lower = (provider or "").strip().lower()
    model_lower = (model or "").strip().lower()

    def _result(reason: FailoverReason, **overrides) -> ClassifiedError:
        """Build a ClassifiedError pre-filled with this call's context."""
        defaults = {
            "reason": reason,
            "status_code": status_code,
            "provider": provider,
            "model": model,
            "message": _extract_message(error, body),
        }
        defaults.update(overrides)
        return ClassifiedError(**defaults)

    # ── 1. Provider-specific patterns (highest priority) ────────────

    # Anthropic thinking block signature invalid (400).
    # Don't gate on provider — OpenRouter proxies Anthropic errors, so the
    # provider may be "openrouter" even though the error is Anthropic-specific.
    # The message pattern ("signature" + "thinking") is unique enough.
    if (
        status_code == 400
        and "signature" in error_msg
        and "thinking" in error_msg
    ):
        return _result(
            FailoverReason.thinking_signature,
            retryable=True,
            should_compress=False,
        )

    # Anthropic long-context tier gate (429 "extra usage" + "long context")
    if (
        status_code == 429
        and "extra usage" in error_msg
        and "long context" in error_msg
    ):
        return _result(
            FailoverReason.long_context_tier,
            retryable=True,
            should_compress=True,
        )

    # ── 2. HTTP status code classification ──────────────────────────

    if status_code is not None:
        classified = _classify_by_status(
            status_code, error_msg, error_code, body,
            provider=provider_lower, model=model_lower,
            approx_tokens=approx_tokens, context_length=context_length,
            num_messages=num_messages,
            result_fn=_result,
        )
        if classified is not None:
            return classified

    # ── 3. Error code classification ────────────────────────────────

    if error_code:
        classified = _classify_by_error_code(error_code, error_msg, _result)
        if classified is not None:
            return classified

    # ── 4. Message pattern matching (no status code) ────────────────

    classified = _classify_by_message(
        error_msg, error_type,
        approx_tokens=approx_tokens,
        context_length=context_length,
        result_fn=_result,
    )
    if classified is not None:
        return classified

    # ── 5. Server disconnect + large session → context overflow ─────
    # Must come BEFORE generic transport error catch — a disconnect on
    # a large session is more likely context overflow than a transient
    # transport hiccup.  Without this ordering, RemoteProtocolError
    # always maps to timeout regardless of session size.

    is_disconnect = any(p in error_msg for p in _SERVER_DISCONNECT_PATTERNS)
    if is_disconnect and not status_code:
        is_large = (
            approx_tokens > context_length * 0.6
            or approx_tokens > 120000
            or num_messages > 200
        )
        if is_large:
            return _result(
                FailoverReason.context_overflow,
                retryable=True,
                should_compress=True,
            )
        return _result(FailoverReason.timeout, retryable=True)

    # ── 6. Transport / timeout heuristics ───────────────────────────

    # The name check covers SDK exception classes that do not derive from
    # the Python builtins; the isinstance check covers the builtins.
    if error_type in _TRANSPORT_ERROR_TYPES or isinstance(error, (TimeoutError, ConnectionError, OSError)):
        return _result(FailoverReason.timeout, retryable=True)

    # ── 7. Fallback: unknown ────────────────────────────────────────

    return _result(FailoverReason.unknown, retryable=True)
|
||||
|
||||
|
||||
# ── Status code classification ──────────────────────────────────────────
|
||||
|
||||
def _classify_by_status(
|
||||
status_code: int,
|
||||
error_msg: str,
|
||||
error_code: str,
|
||||
body: dict,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
approx_tokens: int,
|
||||
context_length: int,
|
||||
num_messages: int = 0,
|
||||
result_fn,
|
||||
) -> Optional[ClassifiedError]:
|
||||
"""Classify based on HTTP status code with message-aware refinement."""
|
||||
|
||||
if status_code == 401:
|
||||
# Not retryable on its own — credential pool rotation and
|
||||
# provider-specific refresh (Codex, Anthropic, Nous) run before
|
||||
# the retryability check in run_agent.py. If those succeed, the
|
||||
# loop `continue`s. If they fail, retryable=False ensures we
|
||||
# hit the client-error abort path (which tries fallback first).
|
||||
return result_fn(
|
||||
FailoverReason.auth,
|
||||
retryable=False,
|
||||
should_rotate_credential=True,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
if status_code == 403:
|
||||
# OpenRouter 403 "key limit exceeded" is actually billing
|
||||
if "key limit exceeded" in error_msg or "spending limit" in error_msg:
|
||||
return result_fn(
|
||||
FailoverReason.billing,
|
||||
retryable=False,
|
||||
should_rotate_credential=True,
|
||||
should_fallback=True,
|
||||
)
|
||||
return result_fn(
|
||||
FailoverReason.auth,
|
||||
retryable=False,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
if status_code == 402:
|
||||
return _classify_402(error_msg, result_fn)
|
||||
|
||||
if status_code == 404:
|
||||
if any(p in error_msg for p in _MODEL_NOT_FOUND_PATTERNS):
|
||||
return result_fn(
|
||||
FailoverReason.model_not_found,
|
||||
retryable=False,
|
||||
should_fallback=True,
|
||||
)
|
||||
# Generic 404 — could be model or endpoint
|
||||
return result_fn(
|
||||
FailoverReason.model_not_found,
|
||||
retryable=False,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
if status_code == 413:
|
||||
return result_fn(
|
||||
FailoverReason.payload_too_large,
|
||||
retryable=True,
|
||||
should_compress=True,
|
||||
)
|
||||
|
||||
if status_code == 429:
|
||||
# Already checked long_context_tier above; this is a normal rate limit
|
||||
return result_fn(
|
||||
FailoverReason.rate_limit,
|
||||
retryable=True,
|
||||
should_rotate_credential=True,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
if status_code == 400:
|
||||
return _classify_400(
|
||||
error_msg, error_code, body,
|
||||
provider=provider, model=model,
|
||||
approx_tokens=approx_tokens,
|
||||
context_length=context_length,
|
||||
num_messages=num_messages,
|
||||
result_fn=result_fn,
|
||||
)
|
||||
|
||||
if status_code in (500, 502):
|
||||
return result_fn(FailoverReason.server_error, retryable=True)
|
||||
|
||||
if status_code in (503, 529):
|
||||
return result_fn(FailoverReason.overloaded, retryable=True)
|
||||
|
||||
# Other 4xx — non-retryable
|
||||
if 400 <= status_code < 500:
|
||||
return result_fn(
|
||||
FailoverReason.format_error,
|
||||
retryable=False,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
# Other 5xx — retryable
|
||||
if 500 <= status_code < 600:
|
||||
return result_fn(FailoverReason.server_error, retryable=True)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _classify_402(error_msg: str, result_fn) -> ClassifiedError:
    """Decide whether a 402 is real billing exhaustion or a transient quota.

    The key insight from OpenClaw: some 402s are transient rate limits
    disguised as payment errors.  "Usage limit, try again in 5 minutes"
    is NOT a billing problem — it's a periodic quota that resets.

    A message counts as transient only when it matches both a usage-limit
    pattern and a transient signal; everything else is treated as billing.
    """
    looks_transient = any(
        pattern in error_msg for pattern in _USAGE_LIMIT_PATTERNS
    ) and any(
        signal in error_msg for signal in _USAGE_LIMIT_TRANSIENT_SIGNALS
    )

    if looks_transient:
        # Periodic quota that resets: behave like a rate limit.
        reason, retryable = FailoverReason.rate_limit, True
    else:
        # Confirmed billing exhaustion: retrying the same credential is futile.
        reason, retryable = FailoverReason.billing, False

    return result_fn(
        reason,
        retryable=retryable,
        should_rotate_credential=True,
        should_fallback=True,
    )
|
||||
|
||||
|
||||
def _classify_400(
    error_msg: str,
    error_code: str,
    body: dict,
    *,
    provider: str,
    model: str,
    approx_tokens: int,
    context_length: int,
    num_messages: int = 0,
    result_fn,
) -> ClassifiedError:
    """Classify a 400 Bad Request.

    Message patterns are tried in a fixed order (context overflow, model not
    found, rate limit, billing).  If none match, a short or empty body
    message combined with a large session is treated as a probable context
    overflow; anything else is a non-retryable format error.
    """
    def _matches(patterns) -> bool:
        return any(pattern in error_msg for pattern in patterns)

    # Explicit context-overflow wording.
    if _matches(_CONTEXT_OVERFLOW_PATTERNS):
        return result_fn(
            FailoverReason.context_overflow,
            retryable=True,
            should_compress=True,
        )

    # Some providers return model-not-found as 400 instead of 404 (e.g. OpenRouter).
    if _matches(_MODEL_NOT_FOUND_PATTERNS):
        return result_fn(
            FailoverReason.model_not_found,
            retryable=False,
            should_fallback=True,
        )

    # Some providers return rate limit / billing errors as 400 instead of
    # 429/402, so check those before falling through to format_error.
    if _matches(_RATE_LIMIT_PATTERNS):
        return result_fn(
            FailoverReason.rate_limit,
            retryable=True,
            should_rotate_credential=True,
            should_fallback=True,
        )
    if _matches(_BILLING_PATTERNS):
        return result_fn(
            FailoverReason.billing,
            retryable=False,
            should_rotate_credential=True,
            should_fallback=True,
        )

    # Generic 400 + large session -> probable context overflow.
    # Anthropic sometimes returns a bare "Error" message when context is too large.
    body_message = ""
    nested_error = body.get("error", {}) if isinstance(body, dict) else None
    if isinstance(nested_error, dict):
        body_message = (nested_error.get("message") or "").strip().lower()

    # Under 30 characters already covers the bare "error" and empty cases.
    message_is_generic = len(body_message) < 30
    session_is_large = (
        approx_tokens > context_length * 0.4
        or approx_tokens > 80000
        or num_messages > 80
    )

    if message_is_generic and session_is_large:
        return result_fn(
            FailoverReason.context_overflow,
            retryable=True,
            should_compress=True,
        )

    # Nothing else applies: malformed request, do not retry.
    return result_fn(
        FailoverReason.format_error,
        retryable=False,
        should_fallback=True,
    )
|
||||
|
||||
|
||||
# ── Error code classification ───────────────────────────────────────────
|
||||
|
||||
def _classify_by_error_code(
|
||||
error_code: str, error_msg: str, result_fn,
|
||||
) -> Optional[ClassifiedError]:
|
||||
"""Classify by structured error codes from the response body."""
|
||||
code_lower = error_code.lower()
|
||||
|
||||
if code_lower in ("resource_exhausted", "throttled", "rate_limit_exceeded"):
|
||||
return result_fn(
|
||||
FailoverReason.rate_limit,
|
||||
retryable=True,
|
||||
should_rotate_credential=True,
|
||||
)
|
||||
|
||||
if code_lower in ("insufficient_quota", "billing_not_active", "payment_required"):
|
||||
return result_fn(
|
||||
FailoverReason.billing,
|
||||
retryable=False,
|
||||
should_rotate_credential=True,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
if code_lower in ("model_not_found", "model_not_available", "invalid_model"):
|
||||
return result_fn(
|
||||
FailoverReason.model_not_found,
|
||||
retryable=False,
|
||||
should_fallback=True,
|
||||
)
|
||||
|
||||
if code_lower in ("context_length_exceeded", "max_tokens_exceeded"):
|
||||
return result_fn(
|
||||
FailoverReason.context_overflow,
|
||||
retryable=True,
|
||||
should_compress=True,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# ── Message pattern classification ──────────────────────────────────────
|
||||
|
||||
def _classify_by_message(
    error_msg: str,
    error_type: str,
    *,
    approx_tokens: int,
    context_length: int,
    result_fn,
) -> Optional[ClassifiedError]:
    """Classify from message text alone, for errors without a status code.

    Pattern groups are tried in priority order and the first group with a
    substring match wins.  Returns ``None`` when no group matches.
    ``error_type``, ``approx_tokens`` and ``context_length`` are accepted
    but not consulted.
    """
    # (patterns, reason, flags) in priority order: payload size, billing,
    # rate limit, context overflow, auth, model not found.
    rules = (
        (
            _PAYLOAD_TOO_LARGE_PATTERNS,
            FailoverReason.payload_too_large,
            {"retryable": True, "should_compress": True},
        ),
        (
            _BILLING_PATTERNS,
            FailoverReason.billing,
            {"retryable": False, "should_rotate_credential": True, "should_fallback": True},
        ),
        (
            _RATE_LIMIT_PATTERNS,
            FailoverReason.rate_limit,
            {"retryable": True, "should_rotate_credential": True, "should_fallback": True},
        ),
        (
            _CONTEXT_OVERFLOW_PATTERNS,
            FailoverReason.context_overflow,
            {"retryable": True, "should_compress": True},
        ),
        (
            _AUTH_PATTERNS,
            FailoverReason.auth,
            {"retryable": True, "should_rotate_credential": True},
        ),
        (
            _MODEL_NOT_FOUND_PATTERNS,
            FailoverReason.model_not_found,
            {"retryable": False, "should_fallback": True},
        ),
    )

    for patterns, reason, flags in rules:
        if any(pattern in error_msg for pattern in patterns):
            return result_fn(reason, **flags)

    return None
|
||||
|
||||
|
||||
# ── Helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _extract_status_code(error: Exception) -> Optional[int]:
|
||||
"""Walk the error and its cause chain to find an HTTP status code."""
|
||||
current = error
|
||||
for _ in range(5): # Max depth to prevent infinite loops
|
||||
code = getattr(current, "status_code", None)
|
||||
if isinstance(code, int):
|
||||
return code
|
||||
# Some SDKs use .status instead of .status_code
|
||||
code = getattr(current, "status", None)
|
||||
if isinstance(code, int) and 100 <= code < 600:
|
||||
return code
|
||||
# Walk cause chain
|
||||
cause = getattr(current, "__cause__", None) or getattr(current, "__context__", None)
|
||||
if cause is None or cause is current:
|
||||
break
|
||||
current = cause
|
||||
return None
|
||||
|
||||
|
||||
def _extract_error_body(error: Exception) -> dict:
|
||||
"""Extract the structured error body from an SDK exception."""
|
||||
body = getattr(error, "body", None)
|
||||
if isinstance(body, dict):
|
||||
return body
|
||||
# Some errors have .response.json()
|
||||
response = getattr(error, "response", None)
|
||||
if response is not None:
|
||||
try:
|
||||
json_body = response.json()
|
||||
if isinstance(json_body, dict):
|
||||
return json_body
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
def _extract_error_code(body: dict) -> str:
|
||||
"""Extract an error code string from the response body."""
|
||||
if not body:
|
||||
return ""
|
||||
error_obj = body.get("error", {})
|
||||
if isinstance(error_obj, dict):
|
||||
code = error_obj.get("code") or error_obj.get("type") or ""
|
||||
if isinstance(code, str) and code.strip():
|
||||
return code.strip()
|
||||
# Top-level code
|
||||
code = body.get("code") or body.get("error_code") or ""
|
||||
if isinstance(code, (str, int)):
|
||||
return str(code).strip()
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_message(error: Exception, body: dict) -> str:
|
||||
"""Extract the most informative error message."""
|
||||
# Try structured body first
|
||||
if body:
|
||||
error_obj = body.get("error", {})
|
||||
if isinstance(error_obj, dict):
|
||||
msg = error_obj.get("message", "")
|
||||
if isinstance(msg, str) and msg.strip():
|
||||
return msg.strip()[:500]
|
||||
msg = body.get("message", "")
|
||||
if isinstance(msg, str) and msg.strip():
|
||||
return msg.strip()[:500]
|
||||
# Fallback to str(error)
|
||||
return str(error)[:500]
|
||||
|
|
@ -197,6 +197,7 @@ _URL_TO_PROVIDER: Dict[str, str] = {
|
|||
"api.githubcopilot.com": "copilot",
|
||||
"models.github.ai": "copilot",
|
||||
"api.fireworks.ai": "fireworks",
|
||||
"opencode.ai": "opencode-go",
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -349,6 +349,13 @@ PLATFORM_HINTS = {
|
|||
"only — no markdown, no formatting. SMS messages are limited to ~1600 "
|
||||
"characters, so be brief and direct."
|
||||
),
|
||||
"bluebubbles": (
|
||||
"You are chatting via iMessage (BlueBubbles). iMessage does not render "
|
||||
"markdown formatting — use plain text. Keep responses concise as they "
|
||||
"appear as text messages. You can send media files natively: include "
|
||||
"MEDIA:/absolute/path/to/file in your response. Images (.jpg, .png, "
|
||||
".heic) appear as photos and other files arrive as attachments."
|
||||
),
|
||||
}
|
||||
|
||||
CONTEXT_FILE_MAX_CHARS = 20_000
|
||||
|
|
|
|||
242
agent/rate_limit_tracker.py
Normal file
242
agent/rate_limit_tracker.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
"""Rate limit tracking for inference API responses.
|
||||
|
||||
Captures x-ratelimit-* headers from provider responses and provides
|
||||
formatted display for the /usage slash command. Currently supports
|
||||
the Nous Portal header format (also used by OpenRouter and OpenAI-compatible
|
||||
APIs that follow the same convention).
|
||||
|
||||
Header schema (12 headers total):
|
||||
x-ratelimit-limit-requests RPM cap
|
||||
x-ratelimit-limit-requests-1h RPH cap
|
||||
x-ratelimit-limit-tokens TPM cap
|
||||
x-ratelimit-limit-tokens-1h TPH cap
|
||||
x-ratelimit-remaining-requests requests left in minute window
|
||||
x-ratelimit-remaining-requests-1h requests left in hour window
|
||||
x-ratelimit-remaining-tokens tokens left in minute window
|
||||
x-ratelimit-remaining-tokens-1h tokens left in hour window
|
||||
x-ratelimit-reset-requests seconds until minute request window resets
|
||||
x-ratelimit-reset-requests-1h seconds until hour request window resets
|
||||
x-ratelimit-reset-tokens seconds until minute token window resets
|
||||
x-ratelimit-reset-tokens-1h seconds until hour token window resets
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
|
||||
|
||||
@dataclass
class RateLimitBucket:
    """A single rate-limit window, e.g. requests per minute."""

    limit: int = 0
    remaining: int = 0
    reset_seconds: float = 0.0
    captured_at: float = 0.0  # time.time() when this was captured

    @property
    def used(self) -> int:
        """Units consumed in this window; never negative."""
        consumed = self.limit - self.remaining
        return max(0, consumed)

    @property
    def usage_pct(self) -> float:
        """Consumed share of the limit as a percentage (0.0 when no limit is set)."""
        return (self.used / self.limit) * 100.0 if self.limit > 0 else 0.0

    @property
    def remaining_seconds_now(self) -> float:
        """Estimated seconds until reset, reduced by the time since capture."""
        seconds_since_capture = time.time() - self.captured_at
        return max(0.0, self.reset_seconds - seconds_since_capture)
|
||||
|
||||
|
||||
@dataclass
class RateLimitState:
    """The four rate-limit windows parsed from one set of response headers."""

    requests_min: RateLimitBucket = field(default_factory=RateLimitBucket)
    requests_hour: RateLimitBucket = field(default_factory=RateLimitBucket)
    tokens_min: RateLimitBucket = field(default_factory=RateLimitBucket)
    tokens_hour: RateLimitBucket = field(default_factory=RateLimitBucket)
    captured_at: float = 0.0  # when the headers were captured
    provider: str = ""

    @property
    def has_data(self) -> bool:
        """True once a capture timestamp has been recorded."""
        return self.captured_at > 0

    @property
    def age_seconds(self) -> float:
        """Seconds since capture, or infinity when nothing has been captured."""
        return time.time() - self.captured_at if self.has_data else float("inf")
|
||||
|
||||
|
||||
def _safe_int(value: Any, default: int = 0) -> int:
|
||||
try:
|
||||
return int(float(value))
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def _safe_float(value: Any, default: float = 0.0) -> float:
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def parse_rate_limit_headers(
    headers: Mapping[str, str],
    provider: str = "",
) -> Optional[RateLimitState]:
    """Parse x-ratelimit-* headers into a RateLimitState.

    Header names are matched case-insensitively, so a plain ``dict`` with
    mixed-case names parses the same as a case-insensitive header mapping.

    Args:
        headers: Response headers; only names starting with
            ``x-ratelimit-`` are used.
        provider: Label stored on the returned state.

    Returns:
        A state whose four buckets share one capture timestamp, or ``None``
        when ``headers`` is empty or holds no rate limit header.  Header
        values that are missing or unparsable become 0.
    """
    if not headers:
        return None

    # Lower-case the names once so the presence check and the lookups agree.
    rate_headers = {}
    for name, value in headers.items():
        lowered = str(name).lower()
        if lowered.startswith("x-ratelimit-"):
            rate_headers[lowered] = value

    if not rate_headers:
        return None

    now = time.time()

    def _bucket(resource: str, suffix: str = "") -> RateLimitBucket:
        # e.g. resource="requests", suffix=""    -> per-minute
        #      resource="tokens",   suffix="-1h" -> per-hour
        tag = f"{resource}{suffix}"
        return RateLimitBucket(
            limit=_safe_int(rate_headers.get(f"x-ratelimit-limit-{tag}")),
            remaining=_safe_int(rate_headers.get(f"x-ratelimit-remaining-{tag}")),
            reset_seconds=_safe_float(rate_headers.get(f"x-ratelimit-reset-{tag}")),
            captured_at=now,
        )

    return RateLimitState(
        requests_min=_bucket("requests"),
        requests_hour=_bucket("requests", "-1h"),
        tokens_min=_bucket("tokens"),
        tokens_hour=_bucket("tokens", "-1h"),
        captured_at=now,
        provider=provider,
    )
|
||||
|
||||
|
||||
# ── Formatting ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _fmt_count(n: int) -> str:
|
||||
"""Human-friendly number: 7999856 -> '8.0M', 33599 -> '33.6K', 799 -> '799'."""
|
||||
if n >= 1_000_000:
|
||||
return f"{n / 1_000_000:.1f}M"
|
||||
if n >= 10_000:
|
||||
return f"{n / 1_000:.1f}K"
|
||||
if n >= 1_000:
|
||||
return f"{n / 1_000:.1f}K"
|
||||
return str(n)
|
||||
|
||||
|
||||
def _fmt_seconds(seconds: float) -> str:
|
||||
"""Seconds -> human-friendly duration: '58s', '2m 14s', '58m 57s', '1h 2m'."""
|
||||
s = max(0, int(seconds))
|
||||
if s < 60:
|
||||
return f"{s}s"
|
||||
if s < 3600:
|
||||
m, sec = divmod(s, 60)
|
||||
return f"{m}m {sec}s" if sec else f"{m}m"
|
||||
h, remainder = divmod(s, 3600)
|
||||
m = remainder // 60
|
||||
return f"{h}h {m}m" if m else f"{h}h"
|
||||
|
||||
|
||||
def _bar(pct: float, width: int = 20) -> str:
|
||||
"""ASCII progress bar: [████████░░░░░░░░░░░░] 40%."""
|
||||
filled = int(pct / 100.0 * width)
|
||||
filled = max(0, min(width, filled))
|
||||
empty = width - filled
|
||||
return f"[{'█' * filled}{'░' * empty}]"
|
||||
|
||||
|
||||
def _bucket_line(label: str, bucket: RateLimitBucket, label_width: int = 14) -> str:
|
||||
"""Format one bucket as a single line."""
|
||||
if bucket.limit <= 0:
|
||||
return f" {label:<{label_width}} (no data)"
|
||||
|
||||
pct = bucket.usage_pct
|
||||
used = _fmt_count(bucket.used)
|
||||
limit = _fmt_count(bucket.limit)
|
||||
remaining = _fmt_count(bucket.remaining)
|
||||
reset = _fmt_seconds(bucket.remaining_seconds_now)
|
||||
|
||||
bar = _bar(pct)
|
||||
return f" {label:<{label_width}} {bar} {pct:5.1f}% {used}/{limit} used ({remaining} left, resets in {reset})"
|
||||
|
||||
|
||||
def format_rate_limit_display(state: RateLimitState) -> str:
    """Build the multi-line rate limit report for terminal/chat display.

    The report has a heading with the provider name and the capture age,
    one line per window (requests then tokens), and a warning line for
    every window with a positive limit whose usage is at or above 80%.
    Returns a short hint instead when no data has been captured.
    """
    if not state.has_data:
        return "No rate limit data yet — make an API request first."

    age = state.age_seconds
    if age < 5:
        freshness = "just now"
    elif age < 60:
        freshness = f"{int(age)}s ago"
    else:
        freshness = f"{_fmt_seconds(age)} ago"

    if state.provider:
        provider_label = state.provider.title()
    else:
        provider_label = "Provider"

    lines = [f"{provider_label} Rate Limits (captured {freshness}):", ""]
    lines.append(_bucket_line("Requests/min", state.requests_min))
    lines.append(_bucket_line("Requests/hr", state.requests_hour))
    lines.append("")
    lines.append(_bucket_line("Tokens/min", state.tokens_min))
    lines.append(_bucket_line("Tokens/hr", state.tokens_hour))

    # Call out any window that is getting hot.
    watched = (
        ("requests/min", state.requests_min),
        ("requests/hr", state.requests_hour),
        ("tokens/min", state.tokens_min),
        ("tokens/hr", state.tokens_hour),
    )
    warnings = []
    for name, bucket in watched:
        if not (bucket.limit > 0 and bucket.usage_pct >= 80):
            continue
        reset = _fmt_seconds(bucket.remaining_seconds_now)
        warnings.append(f" ⚠ {name} at {bucket.usage_pct:.0f}% — resets in {reset}")

    if warnings:
        lines.append("")
        lines.extend(warnings)

    return "\n".join(lines)
|
||||
|
||||
|
||||
def format_rate_limit_compact(state: RateLimitState) -> str:
    """Build a one-line summary for status bars / gateway messages.

    Windows appear in the order RPM, RPH, TPM, TPH, joined by " | ", and
    a window is included only when its limit is positive.  The hourly
    windows also show the time until reset.  Returns a short notice when
    no data has been captured.
    """
    if not state.has_data:
        return "No rate limit data."

    req_min, req_hour = state.requests_min, state.requests_hour
    tok_min, tok_hour = state.tokens_min, state.tokens_hour

    segments = []
    if req_min.limit > 0:
        segments.append(f"RPM: {req_min.remaining}/{req_min.limit}")
    if req_hour.limit > 0:
        reset = _fmt_seconds(req_hour.remaining_seconds_now)
        segments.append(
            f"RPH: {_fmt_count(req_hour.remaining)}/{_fmt_count(req_hour.limit)} (resets {reset})"
        )
    if tok_min.limit > 0:
        segments.append(f"TPM: {_fmt_count(tok_min.remaining)}/{_fmt_count(tok_min.limit)}")
    if tok_hour.limit > 0:
        reset = _fmt_seconds(tok_hour.remaining_seconds_now)
        segments.append(
            f"TPH: {_fmt_count(tok_hour.remaining)}/{_fmt_count(tok_hour.limit)} (resets {reset})"
        )

    return " | ".join(segments)
|
||||
|
|
@ -159,7 +159,10 @@ class SubdirectoryHintTracker:
|
|||
|
||||
def _is_valid_subdir(self, path: Path) -> bool:
|
||||
"""Check if path is a valid directory to scan for hints."""
|
||||
if not path.is_dir():
|
||||
try:
|
||||
if not path.is_dir():
|
||||
return False
|
||||
except OSError:
|
||||
return False
|
||||
if path in self._loaded_dirs:
|
||||
return False
|
||||
|
|
@ -172,7 +175,10 @@ class SubdirectoryHintTracker:
|
|||
found_hints = []
|
||||
for filename in _HINT_FILENAMES:
|
||||
hint_path = directory / filename
|
||||
if not hint_path.is_file():
|
||||
try:
|
||||
if not hint_path.is_file():
|
||||
continue
|
||||
except OSError:
|
||||
continue
|
||||
try:
|
||||
content = hint_path.read_text(encoding="utf-8").strip()
|
||||
|
|
|
|||
|
|
@ -117,7 +117,8 @@ terminal:
|
|||
timeout: 180
|
||||
docker_mount_cwd_to_workspace: false # SECURITY: off by default. Opt in to mount the launch cwd into Docker /workspace.
|
||||
lifetime_seconds: 300
|
||||
# sudo_password: "" # Enable sudo commands (pipes via sudo -S) - SECURITY WARNING: plaintext!
|
||||
# sudo_password: "hunter2" # Optional: pipe a sudo password via sudo -S. SECURITY WARNING: plaintext.
|
||||
# sudo_password: "" # Explicit empty password: try empty and never open the interactive sudo prompt.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 2: SSH remote execution
|
||||
|
|
@ -208,13 +209,18 @@ terminal:
|
|||
#
|
||||
# SECURITY WARNING: Password stored in plaintext!
|
||||
#
|
||||
# INTERACTIVE PROMPT: If no sudo_password is set and the CLI is running,
|
||||
# INTERACTIVE PROMPT: If sudo_password is unset and the CLI is running,
|
||||
# you'll be prompted to enter your password when sudo is needed:
|
||||
# - 45-second timeout (auto-skips if no input)
|
||||
# - Press Enter to skip (command fails gracefully)
|
||||
# - Password is hidden while typing
|
||||
# - Password is cached for the session
|
||||
#
|
||||
# EMPTY PASSWORDS: Setting sudo_password to an explicit empty string is different
|
||||
# from leaving it unset. Hermes will try an empty password via `sudo -S` and
|
||||
# will not open the interactive prompt. This is useful for passwordless sudo,
|
||||
# Touch ID sudo setups, and environments where prompting is just noise.
|
||||
#
|
||||
# ALTERNATIVES:
|
||||
# - SSH backend: Configure passwordless sudo on the remote server
|
||||
# - Containers: Run as root inside the container (no sudo needed)
|
||||
|
|
@ -445,6 +451,16 @@ agent:
|
|||
# Higher = more room for complex tasks, but costs more tokens
|
||||
# Recommended: 20-30 for focused tasks, 50-100 for open exploration
|
||||
max_turns: 60
|
||||
|
||||
# Inactivity timeout for gateway agent runs (seconds, 0 = unlimited).
|
||||
# The agent can run indefinitely when actively calling tools or receiving
|
||||
# API responses. Only fires after the agent has been idle for this duration.
|
||||
# gateway_timeout: 1800
|
||||
|
||||
# Staged warning: send a warning before escalating to full timeout.
|
||||
# Fires once per run when inactivity reaches this threshold (seconds).
|
||||
# Set to 0 to disable the warning.
|
||||
# gateway_timeout_warning: 900
|
||||
|
||||
# Enable verbose logging
|
||||
verbose: false
|
||||
|
|
|
|||
57
cli.py
57
cli.py
|
|
@ -1546,6 +1546,7 @@ class HermesCLI:
|
|||
self._clarify_deadline = 0
|
||||
self._sudo_state = None
|
||||
self._sudo_deadline = 0
|
||||
self._modal_input_snapshot = None
|
||||
self._approval_state = None
|
||||
self._approval_deadline = 0
|
||||
self._approval_lock = threading.Lock()
|
||||
|
|
@ -5408,12 +5409,27 @@ class HermesCLI:
|
|||
print(f" ❌ Compression failed: {e}")
|
||||
|
||||
def _show_usage(self):
|
||||
"""Show cumulative token usage for the current session."""
|
||||
"""Show rate limits (if available) and session token usage."""
|
||||
if not self.agent:
|
||||
print("(._.) No active agent -- send a message first.")
|
||||
return
|
||||
|
||||
agent = self.agent
|
||||
calls = agent.session_api_calls
|
||||
|
||||
if calls == 0:
|
||||
print("(._.) No API calls made yet in this session.")
|
||||
return
|
||||
|
||||
# ── Rate limits (shown first when available) ────────────────
|
||||
rl_state = agent.get_rate_limit_state()
|
||||
if rl_state and rl_state.has_data:
|
||||
from agent.rate_limit_tracker import format_rate_limit_display
|
||||
print()
|
||||
print(format_rate_limit_display(rl_state))
|
||||
print()
|
||||
|
||||
# ── Session token usage ─────────────────────────────────────
|
||||
input_tokens = getattr(agent, "session_input_tokens", 0) or 0
|
||||
output_tokens = getattr(agent, "session_output_tokens", 0) or 0
|
||||
cache_read_tokens = getattr(agent, "session_cache_read_tokens", 0) or 0
|
||||
|
|
@ -5421,13 +5437,7 @@ class HermesCLI:
|
|||
prompt = agent.session_prompt_tokens
|
||||
completion = agent.session_completion_tokens
|
||||
total = agent.session_total_tokens
|
||||
calls = agent.session_api_calls
|
||||
|
||||
if calls == 0:
|
||||
print("(._.) No API calls made yet in this session.")
|
||||
return
|
||||
|
||||
# Current context window state
|
||||
compressor = agent.context_compressor
|
||||
last_prompt = compressor.last_prompt_tokens
|
||||
ctx_len = compressor.context_length
|
||||
|
|
@ -6205,6 +6215,7 @@ class HermesCLI:
|
|||
timeout = 45
|
||||
response_queue = queue.Queue()
|
||||
|
||||
self._capture_modal_input_snapshot()
|
||||
self._sudo_state = {
|
||||
"response_queue": response_queue,
|
||||
}
|
||||
|
|
@ -6217,6 +6228,7 @@ class HermesCLI:
|
|||
result = response_queue.get(timeout=1)
|
||||
self._sudo_state = None
|
||||
self._sudo_deadline = 0
|
||||
self._restore_modal_input_snapshot()
|
||||
self._invalidate()
|
||||
if result:
|
||||
_cprint(f"\n{_DIM} ✓ Password received (cached for session){_RST}")
|
||||
|
|
@ -6231,6 +6243,7 @@ class HermesCLI:
|
|||
|
||||
self._sudo_state = None
|
||||
self._sudo_deadline = 0
|
||||
self._restore_modal_input_snapshot()
|
||||
self._invalidate()
|
||||
_cprint(f"\n{_DIM} ⏱ Timeout — continuing without sudo{_RST}")
|
||||
return ""
|
||||
|
|
@ -6403,6 +6416,33 @@ class HermesCLI:
|
|||
def _secret_capture_callback(self, var_name: str, prompt: str, metadata=None) -> dict:
    """Delegate a secret prompt to ``prompt_for_secret``, passing this CLI instance."""
    return prompt_for_secret(self, var_name, prompt, metadata)
|
||||
|
||||
def _capture_modal_input_snapshot(self) -> None:
|
||||
"""Temporarily clear the input buffer and save the user's in-progress draft."""
|
||||
if self._modal_input_snapshot is not None or not getattr(self, "_app", None):
|
||||
return
|
||||
try:
|
||||
buf = self._app.current_buffer
|
||||
self._modal_input_snapshot = {
|
||||
"text": buf.text,
|
||||
"cursor_position": buf.cursor_position,
|
||||
}
|
||||
buf.reset()
|
||||
except Exception:
|
||||
self._modal_input_snapshot = None
|
||||
|
||||
def _restore_modal_input_snapshot(self) -> None:
|
||||
"""Restore any draft text that was present before a modal prompt opened."""
|
||||
snapshot = self._modal_input_snapshot
|
||||
self._modal_input_snapshot = None
|
||||
if not snapshot or not getattr(self, "_app", None):
|
||||
return
|
||||
try:
|
||||
buf = self._app.current_buffer
|
||||
buf.text = snapshot.get("text", "")
|
||||
buf.cursor_position = min(snapshot.get("cursor_position", 0), len(buf.text))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _submit_secret_response(self, value: str) -> None:
|
||||
if not self._secret_state:
|
||||
return
|
||||
|
|
@ -7130,6 +7170,7 @@ class HermesCLI:
|
|||
# Sudo password prompt state (similar mechanism to clarify)
|
||||
self._sudo_state = None # dict with response_queue when active
|
||||
self._sudo_deadline = 0
|
||||
self._modal_input_snapshot = None
|
||||
|
||||
# Dangerous command approval state (similar mechanism to clarify)
|
||||
self._approval_state = None # dict with command, description, choices, selected, response_queue
|
||||
|
|
@ -7201,7 +7242,6 @@ class HermesCLI:
|
|||
text = event.app.current_buffer.text
|
||||
self._sudo_state["response_queue"].put(text)
|
||||
self._sudo_state = None
|
||||
event.app.current_buffer.reset()
|
||||
event.app.invalidate()
|
||||
return
|
||||
|
||||
|
|
@ -7406,7 +7446,6 @@ class HermesCLI:
|
|||
if self._sudo_state:
|
||||
self._sudo_state["response_queue"].put("")
|
||||
self._sudo_state = None
|
||||
event.app.current_buffer.reset()
|
||||
event.app.invalidate()
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ logger = logging.getLogger(__name__)
|
|||
_KNOWN_DELIVERY_PLATFORMS = frozenset({
|
||||
"telegram", "discord", "slack", "whatsapp", "signal",
|
||||
"matrix", "mattermost", "homeassistant", "dingtalk", "feishu",
|
||||
"wecom", "sms", "email", "webhook",
|
||||
"wecom", "sms", "email", "webhook", "bluebubbles",
|
||||
})
|
||||
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run
|
||||
|
|
@ -91,7 +91,7 @@ def _resolve_delivery_target(job: dict) -> Optional[dict]:
|
|||
}
|
||||
# Origin missing (e.g. job created via API/script) — try each
|
||||
# platform's home channel as a fallback instead of silently dropping.
|
||||
for platform_name in ("matrix", "telegram", "discord", "slack"):
|
||||
for platform_name in ("matrix", "telegram", "discord", "slack", "bluebubbles"):
|
||||
chat_id = os.getenv(f"{platform_name.upper()}_HOME_CHANNEL", "")
|
||||
if chat_id:
|
||||
logger.info(
|
||||
|
|
@ -236,6 +236,7 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option
|
|||
"wecom": Platform.WECOM,
|
||||
"email": Platform.EMAIL,
|
||||
"sms": Platform.SMS,
|
||||
"bluebubbles": Platform.BLUEBUBBLES,
|
||||
}
|
||||
platform = platform_map.get(platform_name.lower())
|
||||
if not platform:
|
||||
|
|
|
|||
8
flake.lock
generated
8
flake.lock
generated
|
|
@ -22,16 +22,16 @@
|
|||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1751274312,
|
||||
"narHash": "sha256-/bVBlRpECLVzjV19t5KMdMFWSwKLtb5RyXdjz3LJT+g=",
|
||||
"lastModified": 1775036866,
|
||||
"narHash": "sha256-ZojAnPuCdy657PbTq5V0Y+AHKhZAIwSIT2cb8UgAz/U=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "50ab793786d9de88ee30ec4e4c24fb4236fc2674",
|
||||
"rev": "6201e203d09599479a3b3450ed24fa81537ebc4e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-24.11",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
description = "Hermes Agent - AI agent framework by Nous Research";
|
||||
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.11";
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
|
||||
flake-parts = {
|
||||
url = "github:hercules-ci/flake-parts";
|
||||
inputs.nixpkgs-lib.follows = "nixpkgs";
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
|
|||
logger.warning("Channel directory: failed to build %s: %s", platform.value, e)
|
||||
|
||||
# Telegram, WhatsApp & Signal can't enumerate chats -- pull from session history
|
||||
for plat_name in ("telegram", "whatsapp", "signal", "email", "sms"):
|
||||
for plat_name in ("telegram", "whatsapp", "signal", "email", "sms", "bluebubbles"):
|
||||
if plat_name not in platforms:
|
||||
platforms[plat_name] = _build_from_sessions(plat_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -63,6 +63,7 @@ class Platform(Enum):
|
|||
WEBHOOK = "webhook"
|
||||
FEISHU = "feishu"
|
||||
WECOM = "wecom"
|
||||
BLUEBUBBLES = "bluebubbles"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -287,6 +288,9 @@ class GatewayConfig:
|
|||
# WeCom uses extra dict for bot credentials
|
||||
elif platform == Platform.WECOM and config.extra.get("bot_id"):
|
||||
connected.append(platform)
|
||||
# BlueBubbles uses extra dict for local server config
|
||||
elif platform == Platform.BLUEBUBBLES and config.extra.get("server_url") and config.extra.get("password"):
|
||||
connected.append(platform)
|
||||
return connected
|
||||
|
||||
def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]:
|
||||
|
|
@ -948,6 +952,29 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
|||
name=os.getenv("WECOM_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# BlueBubbles (iMessage)
|
||||
bluebubbles_server_url = os.getenv("BLUEBUBBLES_SERVER_URL")
|
||||
bluebubbles_password = os.getenv("BLUEBUBBLES_PASSWORD")
|
||||
if bluebubbles_server_url and bluebubbles_password:
|
||||
if Platform.BLUEBUBBLES not in config.platforms:
|
||||
config.platforms[Platform.BLUEBUBBLES] = PlatformConfig()
|
||||
config.platforms[Platform.BLUEBUBBLES].enabled = True
|
||||
config.platforms[Platform.BLUEBUBBLES].extra.update({
|
||||
"server_url": bluebubbles_server_url.rstrip("/"),
|
||||
"password": bluebubbles_password,
|
||||
"webhook_host": os.getenv("BLUEBUBBLES_WEBHOOK_HOST", "127.0.0.1"),
|
||||
"webhook_port": int(os.getenv("BLUEBUBBLES_WEBHOOK_PORT", "8645")),
|
||||
"webhook_path": os.getenv("BLUEBUBBLES_WEBHOOK_PATH", "/bluebubbles-webhook"),
|
||||
"send_read_receipts": os.getenv("BLUEBUBBLES_SEND_READ_RECEIPTS", "true").lower() in ("true", "1", "yes"),
|
||||
})
|
||||
bluebubbles_home = os.getenv("BLUEBUBBLES_HOME_CHANNEL")
|
||||
if bluebubbles_home and Platform.BLUEBUBBLES in config.platforms:
|
||||
config.platforms[Platform.BLUEBUBBLES].home_channel = HomeChannel(
|
||||
platform=Platform.BLUEBUBBLES,
|
||||
chat_id=bluebubbles_home,
|
||||
name=os.getenv("BLUEBUBBLES_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# Session settings
|
||||
idle_minutes = os.getenv("SESSION_IDLE_MINUTES")
|
||||
if idle_minutes:
|
||||
|
|
|
|||
|
|
@ -298,6 +298,7 @@ SUPPORTED_DOCUMENT_TYPES = {
|
|||
".pdf": "application/pdf",
|
||||
".md": "text/markdown",
|
||||
".txt": "text/plain",
|
||||
".log": "text/plain",
|
||||
".zip": "application/zip",
|
||||
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
|
|
@ -407,6 +408,10 @@ class MessageEvent:
|
|||
# Auto-loaded skill for topic/channel bindings (e.g., Telegram DM Topics)
|
||||
auto_skill: Optional[str] = None
|
||||
|
||||
# Internal flag — set for synthetic events (e.g. background process
|
||||
# completion notifications) that must bypass user authorization checks.
|
||||
internal: bool = False
|
||||
|
||||
# Timestamps
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
|
|
|||
828
gateway/platforms/bluebubbles.py
Normal file
828
gateway/platforms/bluebubbles.py
Normal file
|
|
@ -0,0 +1,828 @@
|
|||
"""BlueBubbles iMessage platform adapter.
|
||||
|
||||
Uses the local BlueBubbles macOS server for outbound REST sends and inbound
|
||||
webhooks. Supports text messaging, media attachments (images, voice, video,
|
||||
documents), tapback reactions, typing indicators, and read receipts.
|
||||
|
||||
Architecture based on PR #5869 (benjaminsehl) with inbound attachment
|
||||
downloading from PR #4588 (YuhangLin).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
import httpx
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_image_from_bytes,
|
||||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_WEBHOOK_HOST = "127.0.0.1"
|
||||
DEFAULT_WEBHOOK_PORT = 8645
|
||||
DEFAULT_WEBHOOK_PATH = "/bluebubbles-webhook"
|
||||
MAX_TEXT_LENGTH = 4000
|
||||
|
||||
# Tapback reaction codes (BlueBubbles associatedMessageType values)
|
||||
_TAPBACK_ADDED = {
|
||||
2000: "love", 2001: "like", 2002: "dislike",
|
||||
2003: "laugh", 2004: "emphasize", 2005: "question",
|
||||
}
|
||||
_TAPBACK_REMOVED = {
|
||||
3000: "love", 3001: "like", 3002: "dislike",
|
||||
3003: "laugh", 3004: "emphasize", 3005: "question",
|
||||
}
|
||||
|
||||
# Webhook event types that carry user messages
|
||||
_MESSAGE_EVENTS = {"new-message", "message", "updated-message"}
|
||||
|
||||
# Log redaction patterns
|
||||
_PHONE_RE = re.compile(r"\+?\d{7,15}")
|
||||
_EMAIL_RE = re.compile(r"[\w.+-]+@[\w-]+\.[\w.]+")
|
||||
|
||||
|
||||
def _redact(text: str) -> str:
    """Mask phone numbers and e-mail addresses before *text* is logged.

    Phone-like digit runs are replaced first, then e-mail addresses; every
    match becomes the literal ``[REDACTED]``.
    """
    for pattern in (_PHONE_RE, _EMAIL_RE):
        text = pattern.sub("[REDACTED]", text)
    return text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def check_bluebubbles_requirements() -> bool:
    """Return True when both ``aiohttp`` and ``httpx`` can be imported.

    The modules are imported in that order; the first ``ImportError`` makes
    the function return False.
    """
    for module_name in ("aiohttp", "httpx"):
        try:
            __import__(module_name)
        except ImportError:
            return False
    return True
|
||||
|
||||
|
||||
def _normalize_server_url(raw: str) -> str:
|
||||
value = (raw or "").strip()
|
||||
if not value:
|
||||
return ""
|
||||
if not re.match(r"^https?://", value, flags=re.I):
|
||||
value = f"http://{value}"
|
||||
return value.rstrip("/")
|
||||
|
||||
|
||||
def _strip_markdown(text: str) -> str:
|
||||
"""Strip common markdown formatting for iMessage plain-text delivery."""
|
||||
text = re.sub(r"\*\*(.+?)\*\*", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"\*(.+?)\*", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"__(.+?)__", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"_(.+?)_", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"```[a-zA-Z0-9_+-]*\n?", "", text)
|
||||
text = re.sub(r"`(.+?)`", r"\1", text)
|
||||
text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE)
|
||||
text = re.sub(r"\[([^\]]+)\]\(([^\)]+)\)", r"\1", text)
|
||||
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adapter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class BlueBubblesAdapter(BasePlatformAdapter):
|
||||
platform = Platform.BLUEBUBBLES
|
||||
MAX_MESSAGE_LENGTH = MAX_TEXT_LENGTH
|
||||
|
||||
def __init__(self, config: PlatformConfig):
    """Build the adapter from *config*, falling back to environment variables.

    Each connection setting is taken from ``config.extra`` when it holds a
    truthy value; otherwise the matching ``BLUEBUBBLES_*`` environment
    variable (or the module default) is used.  No network I/O happens here.
    """
    super().__init__(config, Platform.BLUEBUBBLES)
    settings = config.extra or {}

    raw_server_url = settings.get("server_url") or os.getenv(
        "BLUEBUBBLES_SERVER_URL", ""
    )
    self.server_url = _normalize_server_url(raw_server_url)
    self.password = settings.get("password") or os.getenv(
        "BLUEBUBBLES_PASSWORD", ""
    )

    # Where the local aiohttp listener for inbound webhooks will bind.
    self.webhook_host = settings.get("webhook_host") or os.getenv(
        "BLUEBUBBLES_WEBHOOK_HOST", DEFAULT_WEBHOOK_HOST
    )
    raw_port = settings.get("webhook_port") or os.getenv(
        "BLUEBUBBLES_WEBHOOK_PORT", str(DEFAULT_WEBHOOK_PORT)
    )
    self.webhook_port = int(raw_port)
    self.webhook_path = settings.get("webhook_path") or os.getenv(
        "BLUEBUBBLES_WEBHOOK_PATH", DEFAULT_WEBHOOK_PATH
    )
    # Route paths must be absolute.
    if not str(self.webhook_path).startswith("/"):
        self.webhook_path = f"/{self.webhook_path}"

    self.send_read_receipts = bool(settings.get("send_read_receipts", True))

    # Runtime state, populated by connect().
    self.client: Optional[httpx.AsyncClient] = None
    self._runner = None
    self._private_api_enabled: Optional[bool] = None
    self._helper_connected: bool = False
    # Maps a phone/e-mail target to its resolved chat GUID.
    self._guid_cache: Dict[str, str] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# API helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _api_url(self, path: str) -> str:
|
||||
sep = "&" if "?" in path else "?"
|
||||
return f"{self.server_url}{path}{sep}password={quote(self.password, safe='')}"
|
||||
|
||||
async def _api_get(self, path: str) -> Dict[str, Any]:
|
||||
assert self.client is not None
|
||||
res = await self.client.get(self._api_url(path))
|
||||
res.raise_for_status()
|
||||
return res.json()
|
||||
|
||||
async def _api_post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
assert self.client is not None
|
||||
res = await self.client.post(self._api_url(path), json=payload)
|
||||
res.raise_for_status()
|
||||
return res.json()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def connect(self) -> bool:
    """Verify the BlueBubbles server is reachable and start the webhook listener.

    Pings the server, reads its Private API / helper status, then binds an
    aiohttp application exposing ``GET /health`` and ``POST <webhook_path>``.

    Returns:
        True once the listener is running; False when credentials are
        missing, the server cannot be reached, or the listener cannot start.
        On every failure path the HTTP client and runner are released.
    """
    if not self.server_url or not self.password:
        logger.error(
            "[bluebubbles] BLUEBUBBLES_SERVER_URL and BLUEBUBBLES_PASSWORD are required"
        )
        return False
    from aiohttp import web

    self.client = httpx.AsyncClient(timeout=30.0)
    try:
        await self._api_get("/api/v1/ping")
        info = await self._api_get("/api/v1/server/info")
        server_data = (info or {}).get("data", {})
        self._private_api_enabled = bool(server_data.get("private_api"))
        self._helper_connected = bool(server_data.get("helper_connected"))
        logger.info(
            "[bluebubbles] connected to %s (private_api=%s, helper=%s)",
            self.server_url,
            self._private_api_enabled,
            self._helper_connected,
        )
    except Exception as exc:
        # HTTP status errors embed the request URL, which carries the
        # password as a query parameter -- scrub it before logging.
        detail = str(exc)
        for secret in (quote(self.password, safe=""), self.password):
            detail = detail.replace(secret, "[REDACTED]")
        logger.error(
            "[bluebubbles] cannot reach server at %s: %s", self.server_url, detail
        )
        if self.client:
            await self.client.aclose()
            self.client = None
        return False

    app = web.Application()
    app.router.add_get("/health", lambda _: web.Response(text="ok"))
    app.router.add_post(self.webhook_path, self._handle_webhook)
    self._runner = web.AppRunner(app)
    try:
        await self._runner.setup()
        site = web.TCPSite(self._runner, self.webhook_host, self.webhook_port)
        await site.start()
    except Exception as exc:
        # Typically the port is already in use.  Release everything acquired
        # so far instead of leaking the client and a half-started runner.
        logger.error(
            "[bluebubbles] cannot start webhook listener on %s:%s: %s",
            self.webhook_host,
            self.webhook_port,
            exc,
        )
        try:
            await self._runner.cleanup()
        finally:
            self._runner = None
            if self.client:
                await self.client.aclose()
                self.client = None
        return False

    self._mark_connected()
    logger.info(
        "[bluebubbles] webhook listening on http://%s:%s%s",
        self.webhook_host,
        self.webhook_port,
        self.webhook_path,
    )
    return True
|
||||
|
||||
async def disconnect(self) -> None:
    """Close the HTTP client, stop the webhook runner and mark the adapter offline."""
    http_client = self.client
    if http_client:
        await http_client.aclose()
        self.client = None

    runner = self._runner
    if runner:
        await runner.cleanup()
        self._runner = None

    self._mark_disconnected()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Chat GUID resolution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _resolve_chat_guid(self, target: str) -> Optional[str]:
|
||||
"""Resolve an email/phone to a BlueBubbles chat GUID.
|
||||
|
||||
If *target* already contains a semicolon (raw GUID format like
|
||||
``iMessage;-;user@example.com``), it is returned as-is. Otherwise
|
||||
the adapter queries the BlueBubbles chat list and matches on
|
||||
``chatIdentifier`` or participant address.
|
||||
"""
|
||||
target = (target or "").strip()
|
||||
if not target:
|
||||
return None
|
||||
# Already a raw GUID
|
||||
if ";" in target:
|
||||
return target
|
||||
if target in self._guid_cache:
|
||||
return self._guid_cache[target]
|
||||
try:
|
||||
payload = await self._api_post(
|
||||
"/api/v1/chat/query",
|
||||
{"limit": 100, "offset": 0, "with": ["participants"]},
|
||||
)
|
||||
for chat in payload.get("data", []) or []:
|
||||
guid = chat.get("guid") or chat.get("chatGuid")
|
||||
identifier = chat.get("chatIdentifier") or chat.get("identifier")
|
||||
if identifier == target:
|
||||
if guid:
|
||||
self._guid_cache[target] = guid
|
||||
return guid
|
||||
for part in chat.get("participants", []) or []:
|
||||
if (part.get("address") or "").strip() == target and guid:
|
||||
self._guid_cache[target] = guid
|
||||
return guid
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
async def _create_chat_for_handle(
    self, address: str, message: str
) -> SendResult:
    """Create a new chat by sending the first message to *address*.

    Returns:
        A successful ``SendResult`` carrying the GUID reported by the server
        (or ``"ok"`` when none is present); a failed one with the error text
        when the request raises.
    """
    payload = {
        "addresses": [address],
        "message": message,
        # Random rather than timestamp-based: ``datetime.utcnow()`` is
        # deprecated, and two sends in the same instant must not collide.
        "tempGuid": f"temp-{uuid.uuid4().hex}",
    }
    try:
        res = await self._api_post("/api/v1/chat/new", payload)
        data = res.get("data") or {}
        msg_id = data.get("guid") or data.get("messageGuid") or "ok"
        return SendResult(success=True, message_id=str(msg_id), raw_response=res)
    except Exception as exc:
        # HTTP error messages embed the request URL including the password
        # query parameter; scrub it before handing the text to callers.
        detail = str(exc)
        if self.password:
            for secret in (quote(self.password, safe=""), self.password):
                detail = detail.replace(secret, "[REDACTED]")
        return SendResult(success=False, error=detail)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Text sending
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send(
    self,
    chat_id: str,
    content: str,
    reply_to: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
    """Send *content* as plain text to *chat_id*, chunked to the length limit.

    Markdown is stripped first.  When no chat exists for *chat_id*, the
    Private API is enabled and *chat_id* looks like an e-mail address or
    ``+``-prefixed phone number, the first chunk creates the chat and the
    remaining chunks are delivered to it.

    Args:
        chat_id: Chat GUID, e-mail address or phone number.
        content: Message text (may contain markdown).
        reply_to: GUID of the message to reply to; only honoured when the
            Private API helper is connected.
        metadata: Accepted for interface compatibility; unused.

    Returns:
        The result of the last chunk sent, or the first failure.
    """
    text = _strip_markdown(content or "")
    if not text:
        return SendResult(success=False, error="BlueBubbles send requires text")
    chunks = self.truncate_message(text, max_length=self.MAX_MESSAGE_LENGTH)

    can_create_chat = bool(
        self._private_api_enabled
        and ("@" in chat_id or re.match(r"^\+\d+", chat_id))
    )
    as_reply = bool(
        reply_to and self._private_api_enabled and self._helper_connected
    )

    # The target does not change between chunks, so resolve it once.
    guid = await self._resolve_chat_guid(chat_id)
    last = SendResult(success=True)
    for chunk in chunks:
        if not guid:
            if not can_create_chat:
                return SendResult(
                    success=False,
                    error=f"BlueBubbles chat not found for target: {chat_id}",
                )
            # Creating the chat delivers this chunk.  Keep going afterwards
            # so the remaining chunks are not silently dropped.
            last = await self._create_chat_for_handle(chat_id, chunk)
            if not last.success:
                return last
            guid = await self._resolve_chat_guid(chat_id)
            continue

        payload: Dict[str, Any] = {
            "chatGuid": guid,
            # Random rather than timestamp-based (``utcnow`` is deprecated
            # and consecutive chunks could share a timestamp).
            "tempGuid": f"temp-{uuid.uuid4().hex}",
            "message": chunk,
        }
        if as_reply:
            payload["method"] = "private-api"
            payload["selectedMessageGuid"] = reply_to
            payload["partIndex"] = 0
        try:
            res = await self._api_post("/api/v1/message/text", payload)
            data = res.get("data") or {}
            msg_id = data.get("guid") or data.get("messageGuid") or "ok"
            last = SendResult(
                success=True, message_id=str(msg_id), raw_response=res
            )
        except Exception as exc:
            # HTTP error messages embed the password-bearing request URL.
            detail = str(exc)
            if self.password:
                for secret in (quote(self.password, safe=""), self.password):
                    detail = detail.replace(secret, "[REDACTED]")
            return SendResult(success=False, error=detail)
    return last
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Media sending (outbound)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _send_attachment(
    self,
    chat_id: str,
    file_path: str,
    filename: Optional[str] = None,
    caption: Optional[str] = None,
    is_audio_message: bool = False,
) -> SendResult:
    """Send a file attachment via BlueBubbles multipart upload.

    Args:
        chat_id: Chat GUID, e-mail address or phone number.
        file_path: Local path of the file to upload.
        filename: Name to present to the recipient; defaults to the
            basename of *file_path*.
        caption: Optional text sent as a follow-up message, only after the
            upload has succeeded.
        is_audio_message: Flag the upload as an iMessage voice message.

    Returns:
        A ``SendResult`` describing the upload; failures never raise.
    """
    if not self.client:
        return SendResult(success=False, error="Not connected")
    if not os.path.isfile(file_path):
        return SendResult(success=False, error=f"File not found: {file_path}")

    guid = await self._resolve_chat_guid(chat_id)
    if not guid:
        return SendResult(success=False, error=f"Chat not found: {chat_id}")

    fname = filename or os.path.basename(file_path)
    try:
        data: Dict[str, str] = {
            "chatGuid": guid,
            "name": fname,
            "tempGuid": uuid.uuid4().hex,
        }
        if is_audio_message:
            data["isAudioMessage"] = "true"
        # Hold the file open only for the duration of the upload itself.
        with open(file_path, "rb") as f:
            files = {"attachment": (fname, f, "application/octet-stream")}
            res = await self.client.post(
                self._api_url("/api/v1/message/attachment"),
                files=files,
                data=data,
                timeout=120,
            )
        res.raise_for_status()
        result = res.json()

        if result.get("status") != 200:
            # Do not send a caption for an attachment that never arrived.
            return SendResult(
                success=False,
                error=result.get("message", "Attachment upload failed"),
            )

        if caption:
            await self.send(chat_id, caption)

        rdata = result.get("data") or {}
        msg_id = rdata.get("guid") if isinstance(rdata, dict) else None
        return SendResult(success=True, message_id=msg_id, raw_response=result)
    except Exception as e:
        # HTTP error messages embed the password-bearing request URL.
        detail = str(e)
        if self.password:
            for secret in (quote(self.password, safe=""), self.password):
                detail = detail.replace(secret, "[REDACTED]")
        return SendResult(success=False, error=detail)
|
||||
|
||||
async def send_image(
    self,
    chat_id: str,
    image_url: str,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
    """Download *image_url* to the local cache and send it as an attachment.

    If caching or sending raises, the base-class ``send_image`` is used
    instead.
    """
    try:
        from gateway.platforms.base import cache_image_from_url

        cached_path = await cache_image_from_url(image_url)
        outcome = await self._send_attachment(chat_id, cached_path, caption=caption)
        return outcome
    except Exception:
        fallback = await super().send_image(chat_id, image_url, caption, reply_to)
        return fallback
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
return await self._send_attachment(chat_id, image_path, caption=caption)
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
return await self._send_attachment(
|
||||
chat_id, audio_path, caption=caption, is_audio_message=True
|
||||
)
|
||||
|
||||
async def send_video(
|
||||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
return await self._send_attachment(chat_id, video_path, caption=caption)
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
return await self._send_attachment(
|
||||
chat_id, file_path, filename=file_name, caption=caption
|
||||
)
|
||||
|
||||
async def send_animation(
|
||||
self,
|
||||
chat_id: str,
|
||||
animation_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
return await self.send_image(
|
||||
chat_id, animation_url, caption, reply_to, metadata
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Typing indicators
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata=None) -> None:
    """Show the typing indicator in *chat_id* (best-effort).

    Does nothing unless the Private API is enabled, the helper is connected
    and an HTTP client exists.  Any error is swallowed.
    """
    helper_ready = self._private_api_enabled and self._helper_connected
    if not (helper_ready and self.client):
        return
    try:
        guid = await self._resolve_chat_guid(chat_id)
        if not guid:
            return
        typing_url = self._api_url(f"/api/v1/chat/{quote(guid, safe='')}/typing")
        await self.client.post(typing_url, timeout=5)
    except Exception:
        pass
|
||||
|
||||
async def stop_typing(self, chat_id: str) -> None:
    """Clear the typing indicator in *chat_id* (best-effort).

    Does nothing unless the Private API is enabled, the helper is connected
    and an HTTP client exists.  Any error is swallowed.
    """
    helper_ready = self._private_api_enabled and self._helper_connected
    if not (helper_ready and self.client):
        return
    try:
        guid = await self._resolve_chat_guid(chat_id)
        if not guid:
            return
        typing_url = self._api_url(f"/api/v1/chat/{quote(guid, safe='')}/typing")
        await self.client.delete(typing_url, timeout=5)
    except Exception:
        pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Read receipts
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def mark_read(self, chat_id: str) -> bool:
    """Mark *chat_id* as read (best-effort).

    Returns:
        True when the read request was issued; False when the Private API
        helper or client is unavailable, the chat cannot be resolved, or
        the request raises.
    """
    helper_ready = self._private_api_enabled and self._helper_connected
    if not (helper_ready and self.client):
        return False
    try:
        guid = await self._resolve_chat_guid(chat_id)
        if not guid:
            return False
        read_url = self._api_url(f"/api/v1/chat/{quote(guid, safe='')}/read")
        await self.client.post(read_url, timeout=5)
        return True
    except Exception:
        return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tapback reactions
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send_reaction(
    self,
    chat_id: str,
    message_guid: str,
    reaction: str,
    part_index: int = 0,
) -> SendResult:
    """Send a tapback *reaction* to *message_guid* (needs the Private API helper).

    Returns a failed ``SendResult`` when the helper is not connected, the
    chat cannot be resolved, or the request raises.
    """
    helper_ready = self._private_api_enabled and self._helper_connected
    if not helper_ready:
        return SendResult(success=False, error="Private API helper not connected")

    guid = await self._resolve_chat_guid(chat_id)
    if not guid:
        return SendResult(success=False, error=f"Chat not found: {chat_id}")

    body = {
        "chatGuid": guid,
        "selectedMessageGuid": message_guid,
        "reaction": reaction,
        "partIndex": part_index,
    }
    try:
        response = await self._api_post("/api/v1/message/react", body)
        return SendResult(success=True, raw_response=response)
    except Exception as exc:
        return SendResult(success=False, error=str(exc))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Chat info
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
    """Return ``name``, ``type`` and (when known) ``participants`` for *chat_id*.

    The type is ``"group"`` when *chat_id* contains ``;+;`` and ``"dm"``
    otherwise.  Server details are fetched best-effort: on any error the
    defaults (name = *chat_id*) are returned.
    """
    chat_kind = "group" if ";+;" in (chat_id or "") else "dm"
    info: Dict[str, Any] = {"name": chat_id, "type": chat_kind}
    try:
        guid = await self._resolve_chat_guid(chat_id)
        if not guid:
            return info
        response = await self._api_get(
            f"/api/v1/chat/{quote(guid, safe='')}?with=participants"
        )
        data = (response or {}).get("data", {})
        label = data.get("displayName") or data.get("chatIdentifier") or chat_id
        stripped = (
            (member.get("address") or "").strip()
            for member in data.get("participants", []) or []
        )
        addresses = [address for address in stripped if address]
        # Assign only after everything parsed, so a malformed response
        # leaves the defaults untouched.
        info["name"] = label
        if addresses:
            info["participants"] = addresses
    except Exception:
        pass
    return info
|
||||
|
||||
def format_message(self, content: str) -> str:
    """Return *content* with markdown removed, as iMessage is plain text."""
    plain_text = _strip_markdown(content)
    return plain_text
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inbound attachment downloading (from #4588)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _download_attachment(
|
||||
self, att_guid: str, att_meta: Dict[str, Any]
|
||||
) -> Optional[str]:
|
||||
"""Download an attachment from BlueBubbles and cache it locally.
|
||||
|
||||
Returns the local file path on success, None on failure.
|
||||
"""
|
||||
if not self.client:
|
||||
return None
|
||||
try:
|
||||
encoded = quote(att_guid, safe="")
|
||||
resp = await self.client.get(
|
||||
self._api_url(f"/api/v1/attachment/{encoded}/download"),
|
||||
timeout=60,
|
||||
follow_redirects=True,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.content
|
||||
|
||||
mime = (att_meta.get("mimeType") or "").lower()
|
||||
transfer_name = att_meta.get("transferName", "")
|
||||
|
||||
if mime.startswith("image/"):
|
||||
ext_map = {
|
||||
"image/jpeg": ".jpg",
|
||||
"image/png": ".png",
|
||||
"image/gif": ".gif",
|
||||
"image/webp": ".webp",
|
||||
"image/heic": ".jpg",
|
||||
"image/heif": ".jpg",
|
||||
"image/tiff": ".jpg",
|
||||
}
|
||||
ext = ext_map.get(mime, ".jpg")
|
||||
return cache_image_from_bytes(data, ext)
|
||||
|
||||
if mime.startswith("audio/"):
|
||||
ext_map = {
|
||||
"audio/mp3": ".mp3",
|
||||
"audio/mpeg": ".mp3",
|
||||
"audio/ogg": ".ogg",
|
||||
"audio/wav": ".wav",
|
||||
"audio/x-caf": ".mp3",
|
||||
"audio/mp4": ".m4a",
|
||||
"audio/aac": ".m4a",
|
||||
}
|
||||
ext = ext_map.get(mime, ".mp3")
|
||||
return cache_audio_from_bytes(data, ext)
|
||||
|
||||
# Videos, documents, and everything else
|
||||
filename = transfer_name or f"file_{uuid.uuid4().hex[:8]}"
|
||||
return cache_document_from_bytes(data, filename)
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[bluebubbles] failed to download attachment %s: %s",
|
||||
_redact(att_guid),
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Webhook handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _extract_payload_record(
|
||||
self, payload: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
data = payload.get("data")
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
if isinstance(data, list):
|
||||
for item in data:
|
||||
if isinstance(item, dict):
|
||||
return item
|
||||
if isinstance(payload.get("message"), dict):
|
||||
return payload.get("message")
|
||||
return payload if isinstance(payload, dict) else None
|
||||
|
||||
@staticmethod
|
||||
def _value(*candidates: Any) -> Optional[str]:
|
||||
for candidate in candidates:
|
||||
if isinstance(candidate, str) and candidate.strip():
|
||||
return candidate.strip()
|
||||
return None
|
||||
|
||||
async def _handle_webhook(self, request):
    """Handle one inbound BlueBubbles webhook request.

    Flow: authenticate the caller against the server password, decode the
    JSON (or form-wrapped JSON) body, ignore non-message events, our own
    messages and tapbacks, download any attachments, then dispatch a
    ``MessageEvent`` to ``handle_message`` in the background.

    Returns:
        ``401`` for a missing or wrong password, ``400`` for an unparseable
        or incomplete payload, plain ``ok`` otherwise.
    """
    import hmac

    from aiohttp import web

    token = (
        request.query.get("password")
        or request.query.get("guid")
        or request.headers.get("x-password")
        or request.headers.get("x-guid")
        or request.headers.get("x-bluebubbles-guid")
    )
    expected = str(self.password or "")
    # Constant-time comparison: the token comes from an untrusted caller.
    # Bytes are compared because compare_digest rejects non-ASCII str.
    if (
        not token
        or not expected
        or not hmac.compare_digest(token.encode("utf-8"), expected.encode("utf-8"))
    ):
        return web.json_response({"error": "unauthorized"}, status=401)
    try:
        raw = await request.read()
        body = raw.decode("utf-8", errors="replace")
        try:
            payload = json.loads(body)
        except Exception:
            # Fall back to a form-encoded body that wraps the JSON.
            from urllib.parse import parse_qs

            form = parse_qs(body)
            payload_str = (
                form.get("payload")
                or form.get("data")
                or form.get("message")
                or [""]
            )[0]
            payload = json.loads(payload_str) if payload_str else {}
    except Exception as exc:
        logger.error("[bluebubbles] webhook parse error: %s", exc)
        return web.json_response({"error": "invalid payload"}, status=400)

    # Valid JSON that is not an object (list, string, number) cannot be a
    # webhook event; reject it instead of failing on ``.get`` below.
    if not isinstance(payload, dict):
        return web.json_response({"error": "invalid payload"}, status=400)

    event_type = self._value(payload.get("type"), payload.get("event")) or ""
    # Only process message events; silently acknowledge everything else
    if event_type and event_type not in _MESSAGE_EVENTS:
        return web.Response(text="ok")

    record = self._extract_payload_record(payload) or {}
    is_from_me = bool(
        record.get("isFromMe")
        or record.get("fromMe")
        or record.get("is_from_me")
    )
    if is_from_me:
        return web.Response(text="ok")

    # Skip tapback reactions delivered as messages
    assoc_type = record.get("associatedMessageType")
    if isinstance(assoc_type, int) and (
        assoc_type in _TAPBACK_ADDED or assoc_type in _TAPBACK_REMOVED
    ):
        return web.Response(text="ok")

    text = (
        self._value(
            record.get("text"), record.get("message"), record.get("body")
        )
        or ""
    )

    # --- Inbound attachment handling ---
    attachments = record.get("attachments") or []
    if not isinstance(attachments, list):
        attachments = []
    media_urls: List[str] = []
    media_types: List[str] = []
    msg_type = MessageType.TEXT

    for att in attachments:
        # Ignore malformed entries rather than failing the whole webhook.
        if not isinstance(att, dict):
            continue
        att_guid = att.get("guid", "")
        if not att_guid:
            continue
        cached = await self._download_attachment(att_guid, att)
        if cached:
            mime = (att.get("mimeType") or "").lower()
            media_urls.append(cached)
            media_types.append(mime)
            if mime.startswith("image/"):
                msg_type = MessageType.PHOTO
            elif mime.startswith("audio/") or (att.get("uti") or "").endswith(
                "caf"
            ):
                msg_type = MessageType.VOICE
            elif mime.startswith("video/"):
                msg_type = MessageType.VIDEO
            else:
                msg_type = MessageType.DOCUMENT

    # With multiple attachments, prefer PHOTO if any images present
    if len(media_urls) > 1:
        mime_prefixes = {(m or "").split("/")[0] for m in media_types}
        if "image" in mime_prefixes:
            msg_type = MessageType.PHOTO

    if not text and media_urls:
        text = "(attachment)"
    # --- End attachment handling ---

    chat_guid = self._value(
        record.get("chatGuid"),
        payload.get("chatGuid"),
        record.get("chat_guid"),
        payload.get("chat_guid"),
        payload.get("guid"),
    )
    chat_identifier = self._value(
        record.get("chatIdentifier"),
        record.get("identifier"),
        payload.get("chatIdentifier"),
        payload.get("identifier"),
    )
    handle = record.get("handle")
    sender = (
        self._value(
            handle.get("address") if isinstance(handle, dict) else None,
            record.get("sender"),
            record.get("from"),
            record.get("address"),
        )
        or chat_identifier
        or chat_guid
    )
    if not (chat_guid or chat_identifier) and sender:
        chat_identifier = sender
    if not sender or not (chat_guid or chat_identifier) or not text:
        return web.json_response({"error": "missing message fields"}, status=400)

    session_chat_id = chat_guid or chat_identifier
    is_group = bool(record.get("isGroup")) or (";+;" in (chat_guid or ""))
    source = self.build_source(
        chat_id=session_chat_id,
        chat_name=chat_identifier or sender,
        chat_type="group" if is_group else "dm",
        user_id=sender,
        user_name=sender,
        chat_id_alt=chat_identifier,
    )
    event = MessageEvent(
        text=text,
        message_type=msg_type,
        source=source,
        raw_message=payload,
        message_id=self._value(
            record.get("guid"),
            record.get("messageGuid"),
            record.get("id"),
        ),
        reply_to_message_id=self._value(
            record.get("threadOriginatorGuid"),
            record.get("associatedMessageGuid"),
        ),
        media_urls=media_urls,
        media_types=media_types,
    )
    task = asyncio.create_task(self.handle_message(event))
    self._background_tasks.add(task)
    task.add_done_callback(self._background_tasks.discard)

    # Fire-and-forget read receipt.  Keep a reference like the task above:
    # the event loop only holds weak references to tasks, so an untracked
    # task can be garbage-collected before it finishes.
    if self.send_read_receipts and session_chat_id:
        receipt = asyncio.create_task(self.mark_read(session_chat_id))
        self._background_tasks.add(receipt)
        receipt.add_done_callback(self._background_tasks.discard)

    return web.Response(text="ok")
|
||||
|
|
@ -1767,8 +1767,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
if hasattr(interaction.channel, "guild") and interaction.channel.guild:
|
||||
chat_name = f"{interaction.channel.guild.name} / #{chat_name}"
|
||||
|
||||
# Get channel topic (if available)
|
||||
chat_topic = getattr(interaction.channel, "topic", None)
|
||||
# Get channel topic (if available).
|
||||
# For forum threads, inherit the parent forum's topic.
|
||||
chat_topic = self._get_effective_topic(interaction.channel, is_thread=is_thread)
|
||||
|
||||
source = self.build_source(
|
||||
chat_id=str(interaction.channel_id),
|
||||
|
|
@ -1842,6 +1843,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
chat_name = f"{guild_name} / {thread_name}" if guild_name else thread_name
|
||||
|
||||
# Inherit forum topic when the thread was created inside a forum channel.
|
||||
_chan = getattr(interaction, "channel", None)
|
||||
chat_topic = self._get_effective_topic(_chan, is_thread=True) if _chan else None
|
||||
|
||||
source = self.build_source(
|
||||
chat_id=thread_id,
|
||||
chat_name=chat_name,
|
||||
|
|
@ -1849,6 +1854,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
user_id=str(interaction.user.id),
|
||||
user_name=interaction.user.display_name,
|
||||
thread_id=thread_id,
|
||||
chat_topic=chat_topic,
|
||||
)
|
||||
|
||||
event = MessageEvent(
|
||||
|
|
@ -2134,6 +2140,15 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
return True
|
||||
return False
|
||||
|
||||
def _get_effective_topic(self, channel: Any, is_thread: bool = False) -> Optional[str]:
|
||||
"""Return the channel topic, falling back to the parent forum's topic for forum threads."""
|
||||
topic = getattr(channel, "topic", None)
|
||||
if not topic and is_thread:
|
||||
parent = getattr(channel, "parent", None)
|
||||
if parent and self._is_forum_parent(parent):
|
||||
topic = getattr(parent, "topic", None)
|
||||
return topic
|
||||
|
||||
def _format_thread_chat_name(self, thread: Any) -> str:
|
||||
"""Build a readable chat name for thread-like Discord channels, including forum context when available."""
|
||||
thread_name = getattr(thread, "name", None) or str(getattr(thread, "id", "thread"))
|
||||
|
|
@ -2301,8 +2316,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
if hasattr(message.channel, "guild") and message.channel.guild:
|
||||
chat_name = f"{message.channel.guild.name} / #{chat_name}"
|
||||
|
||||
# Get channel topic (if available - TextChannels have topics, DMs/threads don't)
|
||||
chat_topic = getattr(message.channel, "topic", None)
|
||||
# Get channel topic (if available - TextChannels have topics, DMs/threads don't).
|
||||
# For threads whose parent is a forum channel, inherit the parent's topic
|
||||
# so forum descriptions (e.g. project instructions) appear in the session context.
|
||||
chat_topic = self._get_effective_topic(message.channel, is_thread=is_thread)
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
|
|
@ -2365,7 +2382,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
ext or "unknown", content_type,
|
||||
)
|
||||
else:
|
||||
MAX_DOC_BYTES = 20 * 1024 * 1024
|
||||
MAX_DOC_BYTES = 32 * 1024 * 1024
|
||||
if att.size and att.size > MAX_DOC_BYTES:
|
||||
logger.warning(
|
||||
"[Discord] Document too large (%s bytes), skipping: %s",
|
||||
|
|
@ -2389,9 +2406,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
media_urls.append(cached_path)
|
||||
media_types.append(doc_mime)
|
||||
logger.info("[Discord] Cached user document: %s", cached_path)
|
||||
# Inject text content for .txt/.md files (capped at 100 KB)
|
||||
# Inject text content for plain-text documents (capped at 100 KB)
|
||||
MAX_TEXT_INJECT_BYTES = 100 * 1024
|
||||
if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
|
||||
if ext in (".md", ".txt", ".log") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
|
||||
try:
|
||||
text_content = raw_bytes.decode("utf-8")
|
||||
display_name = att.filename or f"document{ext}"
|
||||
|
|
|
|||
|
|
@ -647,7 +647,11 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
|
||||
if result is not None:
|
||||
self._track_sent_timestamp(result)
|
||||
return SendResult(success=True)
|
||||
# Use the timestamp from the RPC result as a pseudo message_id.
|
||||
# Signal doesn't have real message IDs, but the stream consumer
|
||||
# needs a truthy value to follow its edit→fallback path correctly.
|
||||
_msg_id = str(result.get("timestamp", "")) if isinstance(result, dict) else None
|
||||
return SendResult(success=True, message_id=_msg_id or None)
|
||||
return SendResult(success=False, error="RPC send failed")
|
||||
|
||||
def _track_sent_timestamp(self, rpc_result) -> None:
|
||||
|
|
@ -837,6 +841,11 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def stop_typing(self, chat_id: str) -> None:
    """Stop the typing indicator for *chat_id*.

    Public entry point that simply awaits ``self._stop_typing_indicator``.
    It is invoked from the base adapter's ``_keep_typing`` ``finally`` block
    so that platform-level typing tasks get cleaned up.
    """
    await self._stop_typing_indicator(chat_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Chat Info
|
||||
# ------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ import logging
|
|||
import os
|
||||
import re
|
||||
import time
|
||||
from typing import Dict, Optional, Any
|
||||
from typing import Dict, Optional, Any, Tuple
|
||||
|
||||
try:
|
||||
from slack_bolt.async_app import AsyncApp
|
||||
|
|
@ -95,6 +95,12 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
# respond to ALL subsequent messages in that thread automatically.
|
||||
self._mentioned_threads: set = set()
|
||||
self._MENTIONED_THREADS_MAX = 5000
|
||||
# Assistant thread metadata keyed by (channel_id, thread_ts). Slack's
|
||||
# AI Assistant lifecycle events can arrive before/alongside message
|
||||
# events, and they carry the user/thread identity needed for stable
|
||||
# session + memory scoping.
|
||||
self._assistant_threads: Dict[Tuple[str, str], Dict[str, str]] = {}
|
||||
self._ASSISTANT_THREADS_MAX = 5000
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Slack via Socket Mode."""
|
||||
|
|
@ -181,6 +187,14 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
async def handle_app_mention(event, say):
|
||||
pass
|
||||
|
||||
@self._app.event("assistant_thread_started")
|
||||
async def handle_assistant_thread_started(event, say):
|
||||
await self._handle_assistant_thread_lifecycle_event(event)
|
||||
|
||||
@self._app.event("assistant_thread_context_changed")
|
||||
async def handle_assistant_thread_context_changed(event, say):
|
||||
await self._handle_assistant_thread_lifecycle_event(event)
|
||||
|
||||
# Register slash command handler
|
||||
@self._app.command("/hermes")
|
||||
async def handle_hermes_command(ack, command):
|
||||
|
|
@ -755,6 +769,135 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
|
||||
# ----- Internal handlers -----
|
||||
|
||||
def _assistant_thread_key(self, channel_id: str, thread_ts: str) -> Optional[Tuple[str, str]]:
|
||||
"""Return a stable cache key for Slack assistant thread metadata."""
|
||||
if not channel_id or not thread_ts:
|
||||
return None
|
||||
return (str(channel_id), str(thread_ts))
|
||||
|
||||
def _extract_assistant_thread_metadata(self, event: dict) -> Dict[str, str]:
|
||||
"""Extract Slack Assistant thread identity data from an event payload."""
|
||||
assistant_thread = event.get("assistant_thread") or {}
|
||||
context = assistant_thread.get("context") or event.get("context") or {}
|
||||
|
||||
channel_id = (
|
||||
assistant_thread.get("channel_id")
|
||||
or event.get("channel")
|
||||
or context.get("channel_id")
|
||||
or ""
|
||||
)
|
||||
thread_ts = (
|
||||
assistant_thread.get("thread_ts")
|
||||
or event.get("thread_ts")
|
||||
or event.get("message_ts")
|
||||
or ""
|
||||
)
|
||||
user_id = (
|
||||
assistant_thread.get("user_id")
|
||||
or event.get("user")
|
||||
or context.get("user_id")
|
||||
or ""
|
||||
)
|
||||
team_id = (
|
||||
event.get("team")
|
||||
or event.get("team_id")
|
||||
or assistant_thread.get("team_id")
|
||||
or ""
|
||||
)
|
||||
context_channel_id = context.get("channel_id") or ""
|
||||
|
||||
return {
|
||||
"channel_id": str(channel_id) if channel_id else "",
|
||||
"thread_ts": str(thread_ts) if thread_ts else "",
|
||||
"user_id": str(user_id) if user_id else "",
|
||||
"team_id": str(team_id) if team_id else "",
|
||||
"context_channel_id": str(context_channel_id) if context_channel_id else "",
|
||||
}
|
||||
|
||||
def _cache_assistant_thread_metadata(self, metadata: Dict[str, str]) -> None:
|
||||
"""Remember assistant thread identity data for later message events."""
|
||||
channel_id = metadata.get("channel_id", "")
|
||||
thread_ts = metadata.get("thread_ts", "")
|
||||
key = self._assistant_thread_key(channel_id, thread_ts)
|
||||
if not key:
|
||||
return
|
||||
|
||||
existing = self._assistant_threads.get(key, {})
|
||||
merged = dict(existing)
|
||||
merged.update({k: v for k, v in metadata.items() if v})
|
||||
self._assistant_threads[key] = merged
|
||||
|
||||
# Evict oldest entries when the cache exceeds the limit
|
||||
if len(self._assistant_threads) > self._ASSISTANT_THREADS_MAX:
|
||||
excess = len(self._assistant_threads) - self._ASSISTANT_THREADS_MAX // 2
|
||||
for old_key in list(self._assistant_threads)[:excess]:
|
||||
del self._assistant_threads[old_key]
|
||||
|
||||
team_id = merged.get("team_id", "")
|
||||
if team_id and channel_id:
|
||||
self._channel_team[channel_id] = team_id
|
||||
|
||||
def _lookup_assistant_thread_metadata(
|
||||
self,
|
||||
event: dict,
|
||||
channel_id: str = "",
|
||||
thread_ts: str = "",
|
||||
) -> Dict[str, str]:
|
||||
"""Load cached assistant-thread metadata that matches the current event."""
|
||||
metadata = self._extract_assistant_thread_metadata(event)
|
||||
if channel_id and not metadata.get("channel_id"):
|
||||
metadata["channel_id"] = channel_id
|
||||
if thread_ts and not metadata.get("thread_ts"):
|
||||
metadata["thread_ts"] = thread_ts
|
||||
|
||||
key = self._assistant_thread_key(
|
||||
metadata.get("channel_id", ""),
|
||||
metadata.get("thread_ts", ""),
|
||||
)
|
||||
cached = self._assistant_threads.get(key, {}) if key else {}
|
||||
if cached:
|
||||
merged = dict(cached)
|
||||
merged.update({k: v for k, v in metadata.items() if v})
|
||||
return merged
|
||||
return metadata
|
||||
|
||||
def _seed_assistant_thread_session(self, metadata: Dict[str, str]) -> None:
|
||||
"""Prime the session store so assistant threads get stable user scoping."""
|
||||
session_store = getattr(self, "_session_store", None)
|
||||
if not session_store:
|
||||
return
|
||||
|
||||
channel_id = metadata.get("channel_id", "")
|
||||
thread_ts = metadata.get("thread_ts", "")
|
||||
user_id = metadata.get("user_id", "")
|
||||
if not channel_id or not thread_ts or not user_id:
|
||||
return
|
||||
|
||||
source = self.build_source(
|
||||
chat_id=channel_id,
|
||||
chat_name=channel_id,
|
||||
chat_type="dm",
|
||||
user_id=user_id,
|
||||
thread_id=thread_ts,
|
||||
chat_topic=metadata.get("context_channel_id") or None,
|
||||
)
|
||||
|
||||
try:
|
||||
session_store.get_or_create_session(source)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"[Slack] Failed to seed assistant thread session for %s/%s",
|
||||
channel_id,
|
||||
thread_ts,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _handle_assistant_thread_lifecycle_event(self, event: dict) -> None:
|
||||
"""Handle Slack Assistant lifecycle events that carry user/thread identity."""
|
||||
metadata = self._extract_assistant_thread_metadata(event)
|
||||
self._cache_assistant_thread_metadata(metadata)
|
||||
self._seed_assistant_thread_session(metadata)
|
||||
|
||||
async def _handle_slack_message(self, event: dict) -> None:
|
||||
"""Handle an incoming Slack message event."""
|
||||
# Dedup: Slack Socket Mode can redeliver events after reconnects (#4777)
|
||||
|
|
@ -781,10 +924,21 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
return
|
||||
|
||||
text = event.get("text", "")
|
||||
user_id = event.get("user", "")
|
||||
channel_id = event.get("channel", "")
|
||||
ts = event.get("ts", "")
|
||||
team_id = event.get("team", "")
|
||||
assistant_meta = self._lookup_assistant_thread_metadata(
|
||||
event,
|
||||
channel_id=channel_id,
|
||||
thread_ts=event.get("thread_ts", ""),
|
||||
)
|
||||
user_id = event.get("user") or assistant_meta.get("user_id", "")
|
||||
if not channel_id:
|
||||
channel_id = assistant_meta.get("channel_id", "")
|
||||
team_id = (
|
||||
event.get("team")
|
||||
or event.get("team_id")
|
||||
or assistant_meta.get("team_id", "")
|
||||
)
|
||||
|
||||
# Track which workspace owns this channel
|
||||
if team_id and channel_id:
|
||||
|
|
@ -792,6 +946,8 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
|
||||
# Determine if this is a DM or channel message
|
||||
channel_type = event.get("channel_type", "")
|
||||
if not channel_type and channel_id.startswith("D"):
|
||||
channel_type = "im"
|
||||
is_dm = channel_type == "im"
|
||||
|
||||
# Build thread_ts for session keying.
|
||||
|
|
@ -800,7 +956,7 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
# In DMs: only use the real thread_ts — top-level DMs should share
|
||||
# one continuous session, threaded DMs get their own session.
|
||||
if is_dm:
|
||||
thread_ts = event.get("thread_ts") # None for top-level DMs
|
||||
thread_ts = event.get("thread_ts") or assistant_meta.get("thread_ts") # None for top-level DMs
|
||||
else:
|
||||
thread_ts = event.get("thread_ts") or ts # ts fallback for channels
|
||||
|
||||
|
|
|
|||
|
|
@ -184,6 +184,8 @@ if _config_path.exists():
|
|||
# Env var from .env takes precedence (already in os.environ).
|
||||
if "gateway_timeout" in _agent_cfg and "HERMES_AGENT_TIMEOUT" not in os.environ:
|
||||
os.environ["HERMES_AGENT_TIMEOUT"] = str(_agent_cfg["gateway_timeout"])
|
||||
if "gateway_timeout_warning" in _agent_cfg and "HERMES_AGENT_TIMEOUT_WARNING" not in os.environ:
|
||||
os.environ["HERMES_AGENT_TIMEOUT_WARNING"] = str(_agent_cfg["gateway_timeout_warning"])
|
||||
# Timezone: bridge config.yaml → HERMES_TIMEZONE env var.
|
||||
# HERMES_TIMEZONE from .env takes precedence (already in os.environ).
|
||||
_tz_cfg = _cfg.get("timezone", "")
|
||||
|
|
@ -1073,6 +1075,7 @@ class GatewayRunner:
|
|||
"MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS",
|
||||
"FEISHU_ALLOWED_USERS",
|
||||
"WECOM_ALLOWED_USERS",
|
||||
"BLUEBUBBLES_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOWED_USERS")
|
||||
)
|
||||
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") or any(
|
||||
|
|
@ -1083,7 +1086,8 @@ class GatewayRunner:
|
|||
"SMS_ALLOW_ALL_USERS", "MATTERMOST_ALLOW_ALL_USERS",
|
||||
"MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS",
|
||||
"FEISHU_ALLOW_ALL_USERS",
|
||||
"WECOM_ALLOW_ALL_USERS")
|
||||
"WECOM_ALLOW_ALL_USERS",
|
||||
"BLUEBUBBLES_ALLOW_ALL_USERS")
|
||||
)
|
||||
if not _any_allowlist and not _allow_all:
|
||||
logger.warning(
|
||||
|
|
@ -1654,6 +1658,13 @@ class GatewayRunner:
|
|||
adapter.gateway_runner = self # For cross-platform delivery
|
||||
return adapter
|
||||
|
||||
elif platform == Platform.BLUEBUBBLES:
|
||||
from gateway.platforms.bluebubbles import BlueBubblesAdapter, check_bluebubbles_requirements
|
||||
if not check_bluebubbles_requirements():
|
||||
logger.warning("BlueBubbles: aiohttp/httpx missing or BLUEBUBBLES_SERVER_URL/BLUEBUBBLES_PASSWORD not configured")
|
||||
return None
|
||||
return BlueBubblesAdapter(config)
|
||||
|
||||
return None
|
||||
|
||||
def _is_user_authorized(self, source: SessionSource) -> bool:
|
||||
|
|
@ -1692,6 +1703,7 @@ class GatewayRunner:
|
|||
Platform.DINGTALK: "DINGTALK_ALLOWED_USERS",
|
||||
Platform.FEISHU: "FEISHU_ALLOWED_USERS",
|
||||
Platform.WECOM: "WECOM_ALLOWED_USERS",
|
||||
Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOWED_USERS",
|
||||
}
|
||||
platform_allow_all_map = {
|
||||
Platform.TELEGRAM: "TELEGRAM_ALLOW_ALL_USERS",
|
||||
|
|
@ -1706,6 +1718,7 @@ class GatewayRunner:
|
|||
Platform.DINGTALK: "DINGTALK_ALLOW_ALL_USERS",
|
||||
Platform.FEISHU: "FEISHU_ALLOW_ALL_USERS",
|
||||
Platform.WECOM: "WECOM_ALLOW_ALL_USERS",
|
||||
Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOW_ALL_USERS",
|
||||
}
|
||||
|
||||
# Per-platform allow-all flag (e.g., DISCORD_ALLOW_ALL_USERS=true)
|
||||
|
|
@ -1779,8 +1792,11 @@ class GatewayRunner:
|
|||
"""
|
||||
source = event.source
|
||||
|
||||
# Check if user is authorized
|
||||
if not self._is_user_authorized(source):
|
||||
# Internal events (e.g. background-process completion notifications)
|
||||
# are system-generated and must skip user authorization.
|
||||
if getattr(event, "internal", False):
|
||||
pass
|
||||
elif not self._is_user_authorized(source):
|
||||
logger.warning("Unauthorized user: %s (%s) on %s", source.user_id, source.user_name, source.platform.value)
|
||||
# In DMs: offer pairing code. In groups: silently ignore.
|
||||
if source.chat_type == "dm" and self._get_unauthorized_dm_behavior(source.platform) == "pair":
|
||||
|
|
@ -5264,19 +5280,28 @@ class GatewayRunner:
|
|||
|
||||
agent = self._running_agents.get(session_key)
|
||||
if agent and hasattr(agent, "session_total_tokens") and agent.session_api_calls > 0:
|
||||
lines = [
|
||||
"📊 **Session Token Usage**",
|
||||
f"Prompt (input): {agent.session_prompt_tokens:,}",
|
||||
f"Completion (output): {agent.session_completion_tokens:,}",
|
||||
f"Total: {agent.session_total_tokens:,}",
|
||||
f"API calls: {agent.session_api_calls}",
|
||||
]
|
||||
lines = []
|
||||
|
||||
# Rate limits first (when available from provider headers)
|
||||
rl_state = agent.get_rate_limit_state()
|
||||
if rl_state and rl_state.has_data:
|
||||
from agent.rate_limit_tracker import format_rate_limit_compact
|
||||
lines.append(f"⏱️ **Rate Limits:** {format_rate_limit_compact(rl_state)}")
|
||||
lines.append("")
|
||||
|
||||
# Session token usage
|
||||
lines.append("📊 **Session Token Usage**")
|
||||
lines.append(f"Prompt (input): {agent.session_prompt_tokens:,}")
|
||||
lines.append(f"Completion (output): {agent.session_completion_tokens:,}")
|
||||
lines.append(f"Total: {agent.session_total_tokens:,}")
|
||||
lines.append(f"API calls: {agent.session_api_calls}")
|
||||
ctx = agent.context_compressor
|
||||
if ctx.last_prompt_tokens:
|
||||
pct = min(100, ctx.last_prompt_tokens / ctx.context_length * 100) if ctx.context_length else 0
|
||||
lines.append(f"Context: {ctx.last_prompt_tokens:,} / {ctx.context_length:,} ({pct:.0f}%)")
|
||||
if ctx.compression_count:
|
||||
lines.append(f"Compressions: {ctx.compression_count}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
# No running agent -- check session history for a rough count
|
||||
|
|
@ -5518,7 +5543,7 @@ class GatewayRunner:
|
|||
Platform.TELEGRAM, Platform.DISCORD, Platform.SLACK, Platform.WHATSAPP,
|
||||
Platform.SIGNAL, Platform.MATTERMOST, Platform.MATRIX,
|
||||
Platform.HOMEASSISTANT, Platform.EMAIL, Platform.SMS, Platform.DINGTALK,
|
||||
Platform.FEISHU, Platform.WECOM, Platform.LOCAL,
|
||||
Platform.FEISHU, Platform.WECOM, Platform.BLUEBUBBLES, Platform.LOCAL,
|
||||
})
|
||||
|
||||
async def _handle_update_command(self, event: MessageEvent) -> str:
|
||||
|
|
@ -6158,6 +6183,7 @@ class GatewayRunner:
|
|||
text=synth_text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=_source,
|
||||
internal=True,
|
||||
)
|
||||
logger.info(
|
||||
"Process %s finished — injecting agent notification for session %s",
|
||||
|
|
@ -6308,7 +6334,15 @@ class GatewayRunner:
|
|||
# Falls back to env vars for backward compatibility.
|
||||
# YAML 1.1 parses bare `off` as boolean False — normalise before
|
||||
# the `or` chain so it doesn't silently fall through to "all".
|
||||
_raw_tp = user_config.get("display", {}).get("tool_progress")
|
||||
#
|
||||
# Per-platform overrides (display.tool_progress_overrides) take
|
||||
# priority over the global setting — e.g. Signal users can set
|
||||
# tool_progress to "off" while keeping Telegram on "all".
|
||||
_display_cfg = user_config.get("display", {})
|
||||
_overrides = _display_cfg.get("tool_progress_overrides", {})
|
||||
_raw_tp = _overrides.get(platform_key)
|
||||
if _raw_tp is None:
|
||||
_raw_tp = _display_cfg.get("tool_progress")
|
||||
if _raw_tp is False:
|
||||
_raw_tp = "off"
|
||||
progress_mode = (
|
||||
|
|
@ -6412,6 +6446,18 @@ class GatewayRunner:
|
|||
if not adapter:
|
||||
return
|
||||
|
||||
# Skip tool progress for platforms that don't support message
|
||||
# editing (e.g. iMessage/BlueBubbles) — each progress update
|
||||
# would become a separate message bubble, which is noisy.
|
||||
from gateway.platforms.base import BasePlatformAdapter as _BaseAdapter
|
||||
if type(adapter).edit_message is _BaseAdapter.edit_message:
|
||||
while not progress_queue.empty():
|
||||
try:
|
||||
progress_queue.get_nowait()
|
||||
except Exception:
|
||||
break
|
||||
return
|
||||
|
||||
progress_lines = [] # Accumulated tool lines
|
||||
progress_msg_id = None # ID of the progress message to edit
|
||||
can_edit = True # False once an edit fails (platform doesn't support it)
|
||||
|
|
@ -7106,6 +7152,9 @@ class GatewayRunner:
|
|||
# Default 1800s (30 min inactivity). 0 = unlimited.
|
||||
_agent_timeout_raw = float(os.getenv("HERMES_AGENT_TIMEOUT", 1800))
|
||||
_agent_timeout = _agent_timeout_raw if _agent_timeout_raw > 0 else None
|
||||
_agent_warning_raw = float(os.getenv("HERMES_AGENT_TIMEOUT_WARNING", 900))
|
||||
_agent_warning = _agent_warning_raw if _agent_warning_raw > 0 else None
|
||||
_warning_fired = False
|
||||
loop = asyncio.get_event_loop()
|
||||
_executor_task = asyncio.ensure_future(
|
||||
loop.run_in_executor(None, run_sync)
|
||||
|
|
@ -7138,6 +7187,25 @@ class GatewayRunner:
|
|||
_idle_secs = _act.get("seconds_since_activity", 0.0)
|
||||
except Exception:
|
||||
pass
|
||||
# Staged warning: fire once before escalating to full timeout.
|
||||
if (not _warning_fired and _agent_warning is not None
|
||||
and _idle_secs >= _agent_warning):
|
||||
_warning_fired = True
|
||||
_warn_adapter = self.adapters.get(source.platform)
|
||||
if _warn_adapter:
|
||||
_elapsed_warn = int(_agent_warning // 60) or 1
|
||||
_remaining_mins = int((_agent_timeout - _agent_warning) // 60) or 1
|
||||
try:
|
||||
await _warn_adapter.send(
|
||||
source.chat_id,
|
||||
f"⚠️ No activity for {_elapsed_warn} min. "
|
||||
f"If the agent does not respond soon, it will "
|
||||
f"be timed out in {_remaining_mins} min. "
|
||||
f"You can continue waiting or use /reset.",
|
||||
metadata=_status_thread_metadata,
|
||||
)
|
||||
except Exception as _warn_err:
|
||||
logger.debug("Inactivity warning send error: %s", _warn_err)
|
||||
if _idle_secs >= _agent_timeout:
|
||||
_inactivity_timeout = True
|
||||
break
|
||||
|
|
|
|||
|
|
@ -193,6 +193,7 @@ _PII_SAFE_PLATFORMS = frozenset({
|
|||
Platform.WHATSAPP,
|
||||
Platform.SIGNAL,
|
||||
Platform.TELEGRAM,
|
||||
Platform.BLUEBUBBLES,
|
||||
})
|
||||
"""Platforms where user IDs can be safely redacted (no in-message mention system
|
||||
that requires raw IDs). Discord is excluded because mentions use ``<@user_id>``
|
||||
|
|
|
|||
|
|
@ -353,6 +353,17 @@ class GatewayStreamConsumer:
|
|||
self._message_id = result.message_id
|
||||
self._already_sent = True
|
||||
self._last_sent_text = text
|
||||
elif result.success:
|
||||
# Platform accepted the message but returned no message_id
|
||||
# (e.g. Signal). Can't edit without an ID — switch to
|
||||
# fallback mode: suppress intermediate deltas, send only
|
||||
# the missing tail once the final response is ready.
|
||||
self._already_sent = True
|
||||
self._edit_supported = False
|
||||
self._fallback_prefix = self._clean_for_display(text)
|
||||
self._fallback_final_send = True
|
||||
# Sentinel prevents re-entering this branch on every delta
|
||||
self._message_id = "__no_edit__"
|
||||
else:
|
||||
# Initial send failed — disable streaming for this session
|
||||
self._edit_supported = False
|
||||
|
|
|
|||
|
|
@ -295,10 +295,16 @@ def _format_context_length(tokens: int) -> str:
|
|||
"""Format a token count for display (e.g. 128000 → '128K', 1048576 → '1M')."""
|
||||
if tokens >= 1_000_000:
|
||||
val = tokens / 1_000_000
|
||||
return f"{val:g}M"
|
||||
rounded = round(val)
|
||||
if abs(val - rounded) < 0.05:
|
||||
return f"{rounded}M"
|
||||
return f"{val:.1f}M"
|
||||
elif tokens >= 1_000:
|
||||
val = tokens / 1_000
|
||||
return f"{val:g}K"
|
||||
rounded = round(val)
|
||||
if abs(val - rounded) < 0.05:
|
||||
return f"{rounded}K"
|
||||
return f"{val:.1f}K"
|
||||
return str(tokens)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -129,7 +129,7 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
|||
CommandDef("commands", "Browse all commands and skills (paginated)", "Info",
|
||||
gateway_only=True, args_hint="[page]"),
|
||||
CommandDef("help", "Show available commands", "Info"),
|
||||
CommandDef("usage", "Show token usage for the current session", "Info"),
|
||||
CommandDef("usage", "Show token usage and rate limits for the current session", "Info"),
|
||||
CommandDef("insights", "Show usage insights and analytics", "Info",
|
||||
args_hint="[days]"),
|
||||
CommandDef("platforms", "Show gateway/messaging platform status", "Info",
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ _EXTRA_ENV_KEYS = frozenset({
|
|||
"DINGTALK_CLIENT_ID", "DINGTALK_CLIENT_SECRET",
|
||||
"FEISHU_APP_ID", "FEISHU_APP_SECRET", "FEISHU_ENCRYPT_KEY", "FEISHU_VERIFICATION_TOKEN",
|
||||
"WECOM_BOT_ID", "WECOM_SECRET",
|
||||
"BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_PASSWORD",
|
||||
"TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT",
|
||||
"WHATSAPP_MODE", "WHATSAPP_ENABLED",
|
||||
"MATTERMOST_HOME_CHANNEL", "MATTERMOST_REPLY_MODE",
|
||||
|
|
@ -230,6 +231,10 @@ DEFAULT_CONFIG = {
|
|||
# (force on/off for all models), or a list of model-name substrings
|
||||
# to match (e.g. ["gpt", "codex", "gemini", "qwen"]).
|
||||
"tool_use_enforcement": "auto",
|
||||
# Staged inactivity warning: send a warning to the user at this
|
||||
# threshold before escalating to a full timeout. The warning fires
|
||||
# once per run and does not interrupt the agent. 0 = disable warning.
|
||||
"gateway_timeout_warning": 900,
|
||||
},
|
||||
|
||||
"terminal": {
|
||||
|
|
@ -392,6 +397,7 @@ DEFAULT_CONFIG = {
|
|||
"show_cost": False, # Show $ cost in the status bar (off by default)
|
||||
"skin": "default",
|
||||
"tool_progress_command": False, # Enable /verbose command in messaging gateway
|
||||
"tool_progress_overrides": {}, # Per-platform overrides: {"signal": "off", "telegram": "all"}
|
||||
"tool_preview_length": 0, # Max chars for tool call previews (0 = no limit, show full paths/commands)
|
||||
},
|
||||
|
||||
|
|
@ -563,7 +569,7 @@ DEFAULT_CONFIG = {
|
|||
},
|
||||
|
||||
# Config schema version - bump this when adding new required fields
|
||||
"_config_version": 12,
|
||||
"_config_version": 13,
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -1119,6 +1125,27 @@ OPTIONAL_ENV_VARS = {
|
|||
"category": "messaging",
|
||||
"advanced": True,
|
||||
},
|
||||
"BLUEBUBBLES_SERVER_URL": {
|
||||
"description": "BlueBubbles server URL for iMessage integration (e.g. http://192.168.1.10:1234)",
|
||||
"prompt": "BlueBubbles server URL",
|
||||
"url": "https://bluebubbles.app/",
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"BLUEBUBBLES_PASSWORD": {
|
||||
"description": "BlueBubbles server password (from BlueBubbles Server → Settings → API)",
|
||||
"prompt": "BlueBubbles server password",
|
||||
"url": None,
|
||||
"password": True,
|
||||
"category": "messaging",
|
||||
},
|
||||
"BLUEBUBBLES_ALLOWED_USERS": {
|
||||
"description": "Comma-separated iMessage addresses (email or phone) allowed to use the bot",
|
||||
"prompt": "Allowed iMessage addresses (comma-separated)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"GATEWAY_ALLOW_ALL_USERS": {
|
||||
"description": "Allow all users to interact with messaging bots (true/false). Default: false.",
|
||||
"prompt": "Allow all users (true/false)",
|
||||
|
|
@ -1190,7 +1217,7 @@ OPTIONAL_ENV_VARS = {
|
|||
"category": "setting",
|
||||
},
|
||||
"SUDO_PASSWORD": {
|
||||
"description": "Sudo password for terminal commands requiring root access",
|
||||
"description": "Sudo password for terminal commands requiring root access; set to an explicit empty string to try empty without prompting",
|
||||
"prompt": "Sudo password",
|
||||
"url": None,
|
||||
"password": True,
|
||||
|
|
@ -1674,6 +1701,21 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
|
|||
ep = providers_dict[key]
|
||||
print(f" → {key}: {ep.get('api', '')}")
|
||||
|
||||
# ── Version 12 → 13: clear dead LLM_MODEL / OPENAI_MODEL from .env ──
|
||||
# These env vars were written by the old setup wizard but nothing reads
|
||||
# them anymore (config.yaml is the sole source of truth since March 2026).
|
||||
# Stale entries cause user confusion — see issue report.
|
||||
if current_ver < 13:
|
||||
for dead_var in ("LLM_MODEL", "OPENAI_MODEL"):
|
||||
try:
|
||||
old_val = get_env_value(dead_var)
|
||||
if old_val:
|
||||
save_env_value(dead_var, "")
|
||||
if not quiet:
|
||||
print(f" ✓ Cleared {dead_var} from .env (no longer used — config.yaml is source of truth)")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if current_ver < latest_ver and not quiet:
|
||||
print(f"Config version: {current_ver} → {latest_ver}")
|
||||
|
||||
|
|
|
|||
337
hermes_cli/dump.py
Normal file
337
hermes_cli/dump.py
Normal file
|
|
@ -0,0 +1,337 @@
|
|||
"""
|
||||
Dump command for hermes CLI.
|
||||
|
||||
Outputs a compact, plain-text summary of the user's Hermes setup
|
||||
that can be copy-pasted into Discord/GitHub/Telegram for support context.
|
||||
No ANSI colors, no checkmarks — just data.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_cli.config import get_hermes_home, get_env_path, get_project_root, load_config
|
||||
from hermes_constants import display_hermes_home
|
||||
|
||||
|
||||
def _get_git_commit(project_root: Path) -> str:
|
||||
"""Return short git commit hash, or '(unknown)'."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "--short=8", "HEAD"],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
cwd=str(project_root),
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout.strip()
|
||||
except Exception:
|
||||
pass
|
||||
return "(unknown)"
|
||||
|
||||
|
||||
def _key_present(name: str) -> str:
|
||||
"""Return 'set' or 'not set' for an env var."""
|
||||
return "set" if os.getenv(name) else "not set"
|
||||
|
||||
|
||||
def _redact(value: str) -> str:
|
||||
"""Redact all but first 4 and last 4 chars."""
|
||||
if not value:
|
||||
return ""
|
||||
if len(value) < 12:
|
||||
return "***"
|
||||
return value[:4] + "..." + value[-4:]
|
||||
|
||||
|
||||
def _gateway_status() -> str:
    """Return a short, human-readable gateway service status.

    On Linux the user-level systemd unit is queried with
    ``systemctl --user is-active``; on macOS the launchd label is looked up
    with ``launchctl list``. Other platforms report 'N/A'. Any error while
    querying the service manager is reported as 'unknown'.
    """
    if sys.platform.startswith("linux"):
        # Resolve the unit name; fall back to the stock name if the
        # gateway module cannot be imported or fails.
        try:
            from hermes_cli.gateway import get_service_name
            unit = get_service_name()
        except Exception:
            unit = "hermes-gateway"
        try:
            proc = subprocess.run(
                ["systemctl", "--user", "is-active", unit],
                capture_output=True, text=True, timeout=5,
            )
        except Exception:
            return "unknown"
        if proc.stdout.strip() == "active":
            return "running (systemd)"
        return "stopped"

    if sys.platform == "darwin":
        try:
            from hermes_cli.gateway import get_launchd_label
            proc = subprocess.run(
                ["launchctl", "list", get_launchd_label()],
                capture_output=True, text=True, timeout=5,
            )
        except Exception:
            return "unknown"
        if proc.returncode == 0:
            return "loaded (launchd)"
        return "not loaded"

    return "N/A"
|
||||
|
||||
|
||||
def _count_skills(hermes_home: Path) -> int:
    """Count installed skills under ``<hermes_home>/skills``.

    Every ``SKILL.md`` file found at any depth counts as one skill.
    Returns 0 when the skills directory does not exist.
    """
    skills_root = hermes_home / "skills"
    if not skills_root.is_dir():
        return 0
    return sum(1 for _ in skills_root.rglob("SKILL.md"))
|
||||
|
||||
|
||||
def _count_mcp_servers(config: dict) -> int:
    """Count configured MCP servers.

    Reads ``config["mcp"]["servers"]`` and returns the number of entries.

    Args:
        config: Loaded Hermes configuration mapping.

    Returns:
        Number of server entries, or 0 when the ``mcp`` section or its
        ``servers`` entry is missing, null, or not a dict/list.
    """
    mcp = config.get("mcp")
    # A key present with a null or scalar value is treated as "nothing
    # configured" rather than crashing the whole dump.
    if not isinstance(mcp, dict):
        return 0
    servers = mcp.get("servers")
    if not isinstance(servers, (dict, list)):
        return 0
    return len(servers)
|
||||
|
||||
|
||||
def _cron_summary(hermes_home: Path) -> str:
    """Summarise the cron jobs stored in ``<hermes_home>/cron/jobs.json``.

    Returns '0' when the file is absent, '<active> active / <total> total'
    when it parses (a job without an ``enabled`` key counts as active), and
    '(error reading)' when the file cannot be read or has an unexpected
    structure.
    """
    jobs_path = hermes_home / "cron" / "jobs.json"
    if not jobs_path.exists():
        return "0"
    try:
        with open(jobs_path, encoding="utf-8") as fh:
            payload = json.load(fh)
        all_jobs = payload.get("jobs", [])
        enabled = [job for job in all_jobs if job.get("enabled", True)]
        return f"{len(enabled)} active / {len(all_jobs)} total"
    except Exception:
        # Malformed or unreadable job store: report it, never raise.
        return "(error reading)"
|
||||
|
||||
|
||||
def _configured_platforms() -> list[str]:
    """Return the names of messaging platforms that look configured.

    A platform counts as configured when its marker environment variable
    is set to a non-empty value. Names are returned in the order of the
    table below.
    """
    checks = {
        "telegram": "TELEGRAM_BOT_TOKEN",
        "discord": "DISCORD_BOT_TOKEN",
        "slack": "SLACK_BOT_TOKEN",
        "whatsapp": "WHATSAPP_ENABLED",
        "signal": "SIGNAL_HTTP_URL",
        "email": "EMAIL_ADDRESS",
        "sms": "TWILIO_ACCOUNT_SID",
        "matrix": "MATRIX_HOMESERVER_URL",
        "mattermost": "MATTERMOST_URL",
        "homeassistant": "HASS_TOKEN",
        "dingtalk": "DINGTALK_CLIENT_ID",
        "feishu": "FEISHU_APP_ID",
        "wecom": "WECOM_BOT_ID",
        # Same marker variable the status table and gateway setup use
        # to detect BlueBubbles; appended last so existing order is stable.
        "bluebubbles": "BLUEBUBBLES_SERVER_URL",
    }
    return [name for name, env in checks.items() if os.getenv(env)]
|
||||
|
||||
|
||||
def _memory_provider(config: dict) -> str:
    """Return the active memory provider name.

    Args:
        config: Loaded Hermes configuration mapping.

    Returns:
        The value of ``config["memory"]["provider"]`` when it is set to a
        truthy value, otherwise 'built-in'. A ``memory`` section that is
        missing, null, or not a mapping also yields 'built-in'.
    """
    mem = config.get("memory")
    # Tolerate a null or scalar "memory" entry instead of raising
    # AttributeError on .get().
    if not isinstance(mem, dict):
        return "built-in"
    return mem.get("provider") or "built-in"
|
||||
|
||||
|
||||
def _get_model_and_provider(config: dict) -> tuple[str, str]:
    """Extract the ``(model, provider)`` pair from the config.

    ``config["model"]`` may be a plain string (the model name) or a
    mapping, in which case the model is the first truthy value among its
    ``default``, ``model`` and ``name`` keys and the provider comes from
    ``provider``. Missing values are reported as '(not set)' for the model
    and '(auto)' for the provider.
    """
    raw = config.get("model", "")
    if isinstance(raw, str):
        return raw or "(not set)", "(auto)"
    if isinstance(raw, dict):
        name = raw.get("default") or raw.get("model") or raw.get("name") or "(not set)"
        return name, raw.get("provider") or "(auto)"
    # Any other type (None, list, number, ...) is treated as unset.
    return "(not set)", "(auto)"
|
||||
|
||||
|
||||
def _config_overrides(config: dict) -> dict[str, str]:
    """Collect user config values that differ from the defaults.

    Only a fixed set of user-facing settings is inspected, plus the
    ``toolsets`` list and any ``fallback_providers``.

    Returns a flat dict of dotpath -> stringified value.
    """
    from hermes_cli.config import DEFAULT_CONFIG

    # (section, key) pairs considered worth surfacing in a support dump.
    tracked_keys = (
        ("agent", "max_turns"),
        ("agent", "gateway_timeout"),
        ("agent", "tool_use_enforcement"),
        ("terminal", "backend"),
        ("terminal", "docker_image"),
        ("terminal", "persistent_shell"),
        ("browser", "allow_private_urls"),
        ("compression", "enabled"),
        ("compression", "threshold"),
        ("display", "streaming"),
        ("display", "skin"),
        ("display", "show_reasoning"),
        ("smart_model_routing", "enabled"),
        ("privacy", "redact_pii"),
        ("tts", "provider"),
    )

    found = {}
    for section, key in tracked_keys:
        defaults = DEFAULT_CONFIG.get(section, {})
        current = config.get(section, {})
        # Skip sections that are not mappings on either side.
        if isinstance(defaults, dict) and isinstance(current, dict):
            value = current.get(key)
            if value is not None and value != defaults.get(key):
                found[f"{section}.{key}"] = str(value)

    # Toolsets are reported whenever they differ from the default list.
    chosen_toolsets = config.get("toolsets", [])
    if chosen_toolsets != DEFAULT_CONFIG.get("toolsets", []):
        found["toolsets"] = str(chosen_toolsets)

    # Fallback providers are reported whenever any are configured.
    fallback_list = config.get("fallback_providers", [])
    if fallback_list:
        found["fallback_providers"] = str(fallback_list)

    return found
|
||||
|
||||
|
||||
def run_dump(args):
    """Output a compact, copy-pasteable setup summary.

    Collects version, OS, Python and OpenAI SDK versions, the active
    profile, model/provider, terminal backend, API key presence, feature
    counts and non-default config values, then prints them to stdout as
    plain text between ``--- hermes dump ---`` and ``--- end dump ---``.

    Args:
        args: Parsed CLI namespace. Only ``show_keys`` is read (default
            False); when true, keys that are set are printed via
            ``_redact`` instead of the word "set".
    """
    show_keys = getattr(args, "show_keys", False)

    # Load env from .env file so key checks work
    from dotenv import load_dotenv
    env_path = get_env_path()
    if env_path.exists():
        try:
            load_dotenv(env_path, encoding="utf-8")
        except UnicodeDecodeError:
            # Retry with a permissive single-byte encoding for non-UTF-8 files.
            load_dotenv(env_path, encoding="latin-1")
    # Also try project .env as dev fallback
    # (override=False: values already in the environment win).
    load_dotenv(get_project_root() / ".env", override=False, encoding="utf-8")

    project_root = get_project_root()
    hermes_home = get_hermes_home()

    # Version metadata is optional; fall back to placeholders if the
    # package does not expose it.
    try:
        from hermes_cli import __version__, __release_date__
    except ImportError:
        __version__ = "(unknown)"
        __release_date__ = ""

    commit = _get_git_commit(project_root)

    # A broken config must not prevent the dump; continue with an empty one.
    try:
        config = load_config()
    except Exception:
        config = {}

    model, provider = _get_model_and_provider(config)

    # Profile
    try:
        from hermes_cli.profiles import get_active_profile_name
        profile = get_active_profile_name() or "(default)"
    except Exception:
        profile = "(default)"

    # Terminal backend
    # NOTE(review): assumes config["terminal"] is a mapping when present — confirm.
    terminal_cfg = config.get("terminal", {})
    backend = terminal_cfg.get("backend", "local")

    # OpenAI SDK version
    try:
        import openai
        openai_ver = openai.__version__
    except ImportError:
        openai_ver = "not installed"

    # OS info
    os_info = f"{platform.system()} {platform.release()} {platform.machine()}"

    # Header section: one "key: value" line per fact.
    lines = []
    lines.append("--- hermes dump ---")
    ver_str = f"{__version__}"
    if __release_date__:
        ver_str += f" ({__release_date__})"
    ver_str += f" [{commit}]"
    lines.append(f"version: {ver_str}")
    lines.append(f"os: {os_info}")
    lines.append(f"python: {sys.version.split()[0]}")
    lines.append(f"openai_sdk: {openai_ver}")
    lines.append(f"profile: {profile}")
    lines.append(f"hermes_home: {display_hermes_home()}")
    lines.append(f"model: {model}")
    lines.append(f"provider: {provider}")
    lines.append(f"terminal: {backend}")

    # API keys
    lines.append("")
    lines.append("api_keys:")
    # (environment variable, label printed in the dump)
    api_keys = [
        ("OPENROUTER_API_KEY", "openrouter"),
        ("OPENAI_API_KEY", "openai"),
        ("ANTHROPIC_API_KEY", "anthropic"),
        ("ANTHROPIC_TOKEN", "anthropic_token"),
        ("NOUS_API_KEY", "nous"),
        ("GLM_API_KEY", "glm/zai"),
        ("ZAI_API_KEY", "zai"),
        ("KIMI_API_KEY", "kimi"),
        ("MINIMAX_API_KEY", "minimax"),
        ("DEEPSEEK_API_KEY", "deepseek"),
        ("DASHSCOPE_API_KEY", "dashscope"),
        ("HF_TOKEN", "huggingface"),
        ("AI_GATEWAY_API_KEY", "ai_gateway"),
        ("OPENCODE_ZEN_API_KEY", "opencode_zen"),
        ("OPENCODE_GO_API_KEY", "opencode_go"),
        ("KILOCODE_API_KEY", "kilocode"),
        ("FIRECRAWL_API_KEY", "firecrawl"),
        ("TAVILY_API_KEY", "tavily"),
        ("BROWSERBASE_API_KEY", "browserbase"),
        ("FAL_KEY", "fal"),
        ("ELEVENLABS_API_KEY", "elevenlabs"),
        ("GITHUB_TOKEN", "github"),
    ]

    for env_var, label in api_keys:
        val = os.getenv(env_var, "")
        # Full key values are never printed: either a redacted form
        # (with --show-keys) or just set / not set.
        if show_keys and val:
            display = _redact(val)
        else:
            display = "set" if val else "not set"
        lines.append(f" {label:<20} {display}")

    # Features summary
    lines.append("")
    lines.append("features:")

    toolsets = config.get("toolsets", ["hermes-cli"])
    lines.append(f" toolsets: {', '.join(toolsets) if toolsets else '(default)'}")
    lines.append(f" mcp_servers: {_count_mcp_servers(config)}")
    lines.append(f" memory_provider: {_memory_provider(config)}")
    lines.append(f" gateway: {_gateway_status()}")

    platforms = _configured_platforms()
    lines.append(f" platforms: {', '.join(platforms) if platforms else 'none'}")
    lines.append(f" cron_jobs: {_cron_summary(hermes_home)}")
    lines.append(f" skills: {_count_skills(hermes_home)}")

    # Config overrides (non-default values)
    # The section is omitted entirely when nothing differs from the defaults.
    overrides = _config_overrides(config)
    if overrides:
        lines.append("")
        lines.append("config_overrides:")
        for key, val in overrides.items():
            lines.append(f" {key}: {val}")

    lines.append("--- end dump ---")

    output = "\n".join(lines)
    print(output)
|
||||
|
|
@ -1588,6 +1588,34 @@ _PLATFORMS = [
|
|||
"help": "Chat ID for scheduled results and notifications."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"key": "bluebubbles",
|
||||
"label": "BlueBubbles (iMessage)",
|
||||
"emoji": "💬",
|
||||
"token_var": "BLUEBUBBLES_SERVER_URL",
|
||||
"setup_instructions": [
|
||||
"1. Install BlueBubbles on a Mac that will act as your iMessage server:",
|
||||
" https://bluebubbles.app/",
|
||||
"2. Complete the BlueBubbles setup wizard — sign in with your Apple ID",
|
||||
"3. In BlueBubbles Settings → API, note the Server URL and password",
|
||||
"4. The server URL is typically http://<your-mac-ip>:1234",
|
||||
"5. Hermes connects via the BlueBubbles REST API and receives",
|
||||
" incoming messages via a local webhook",
|
||||
"6. To authorize users, use DM pairing: hermes pairing generate bluebubbles",
|
||||
" Share the code — the user sends it via iMessage to get approved",
|
||||
],
|
||||
"vars": [
|
||||
{"name": "BLUEBUBBLES_SERVER_URL", "prompt": "BlueBubbles server URL (e.g. http://192.168.1.10:1234)", "password": False,
|
||||
"help": "The URL shown in BlueBubbles Settings → API."},
|
||||
{"name": "BLUEBUBBLES_PASSWORD", "prompt": "BlueBubbles server password", "password": True,
|
||||
"help": "The password shown in BlueBubbles Settings → API."},
|
||||
{"name": "BLUEBUBBLES_ALLOWED_USERS", "prompt": "Pre-authorized phone numbers or iMessage IDs (comma-separated, or leave empty for DM pairing)", "password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Optional — pre-authorize specific users. Leave empty to use DM pairing instead (recommended)."},
|
||||
{"name": "BLUEBUBBLES_HOME_CHANNEL", "prompt": "Home channel (phone number or iMessage ID for cron/notifications, or empty)", "password": False,
|
||||
"help": "Phone number or Apple ID to deliver cron results and notifications to."},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1585,7 +1585,11 @@ def _model_flow_custom(config):
|
|||
f"Hermes will still save it."
|
||||
)
|
||||
if probe.get("suggested_base_url"):
|
||||
print(f" If this server expects /v1, try base URL: {probe['suggested_base_url']}")
|
||||
suggested = probe["suggested_base_url"]
|
||||
if suggested.endswith("/v1"):
|
||||
print(f" If this server expects /v1 in the path, try base URL: {suggested}")
|
||||
else:
|
||||
print(f" If /v1 should not be in the base URL, try: {suggested}")
|
||||
|
||||
# Select model — use probe results when available, fall back to manual input
|
||||
model_name = ""
|
||||
|
|
@ -2750,6 +2754,12 @@ def cmd_doctor(args):
|
|||
run_doctor(args)
|
||||
|
||||
|
||||
def cmd_dump(args):
    """Dump setup summary for support/debugging.

    CLI handler for the ``dump`` subcommand: forwards the parsed
    arguments to ``hermes_cli.dump.run_dump``.
    """
    # Imported inside the handler so the dump module is only loaded when
    # this command actually runs.
    from hermes_cli.dump import run_dump
    run_dump(args)
|
||||
|
||||
|
||||
def cmd_config(args):
|
||||
"""Configuration management."""
|
||||
from hermes_cli.config import config_command
|
||||
|
|
@ -4843,6 +4853,22 @@ For more help on a command:
|
|||
help="Attempt to fix issues automatically"
|
||||
)
|
||||
doctor_parser.set_defaults(func=cmd_doctor)
|
||||
|
||||
# =========================================================================
|
||||
# dump command
|
||||
# =========================================================================
|
||||
dump_parser = subparsers.add_parser(
|
||||
"dump",
|
||||
help="Dump setup summary for support/debugging",
|
||||
description="Output a compact, plain-text summary of your Hermes setup "
|
||||
"that can be copy-pasted into Discord/GitHub for support context"
|
||||
)
|
||||
dump_parser.add_argument(
|
||||
"--show-keys",
|
||||
action="store_true",
|
||||
help="Show redacted API key prefixes (first/last 4 chars) instead of just set/not set"
|
||||
)
|
||||
dump_parser.set_defaults(func=cmd_dump)
|
||||
|
||||
# =========================================================================
|
||||
# config command
|
||||
|
|
|
|||
|
|
@ -537,8 +537,11 @@ def switch_model(
|
|||
)
|
||||
else:
|
||||
# --- Step c: On aggregator, convert vendor:model to vendor/model ---
|
||||
# Only convert when there's no slash — a slash means the name
|
||||
# is already in vendor/model format and the colon is a variant
|
||||
# tag (:free, :extended, :fast) that must be preserved.
|
||||
colon_pos = raw_input.find(":")
|
||||
if colon_pos > 0 and is_aggregator(current_provider):
|
||||
if colon_pos > 0 and "/" not in raw_input and is_aggregator(current_provider):
|
||||
left = raw_input[:colon_pos].strip().lower()
|
||||
right = raw_input[colon_pos + 1:].strip()
|
||||
if left and right:
|
||||
|
|
|
|||
|
|
@ -1532,7 +1532,7 @@ def probe_api_models(
|
|||
|
||||
return {
|
||||
"models": None,
|
||||
"probed_url": tried[-1] if tried else normalized.rstrip("/") + "/models",
|
||||
"probed_url": tried[0] if tried else normalized.rstrip("/") + "/models",
|
||||
"resolved_base_url": normalized,
|
||||
"suggested_base_url": alternate_base if alternate_base != normalized else None,
|
||||
"used_fallback": False,
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ _RESERVED_NAMES = frozenset({
|
|||
# Hermes subcommands that cannot be used as profile names/aliases
|
||||
_HERMES_SUBCOMMANDS = frozenset({
|
||||
"chat", "model", "gateway", "setup", "whatsapp", "login", "logout",
|
||||
"status", "cron", "doctor", "config", "pairing", "skills", "tools",
|
||||
"status", "cron", "doctor", "dump", "config", "pairing", "skills", "tools",
|
||||
"mcp", "sessions", "insights", "version", "update", "uninstall",
|
||||
"profile", "plugins", "honcho", "acp",
|
||||
})
|
||||
|
|
@ -1007,7 +1007,7 @@ _hermes_completion() {
|
|||
|
||||
# Top-level subcommands
|
||||
if [[ "$COMP_CWORD" == 1 ]]; then
|
||||
local commands="chat model gateway setup status cron doctor config skills tools mcp sessions profile update version"
|
||||
local commands="chat model gateway setup status cron doctor dump config skills tools mcp sessions profile update version"
|
||||
COMPREPLY=($(compgen -W "$commands" -- "$cur"))
|
||||
fi
|
||||
}
|
||||
|
|
@ -1032,7 +1032,7 @@ _hermes() {
|
|||
_arguments \\
|
||||
'-p[Profile name]:profile:($profiles)' \\
|
||||
'--profile[Profile name]:profile:($profiles)' \\
|
||||
'1:command:(chat model gateway setup status cron doctor config skills tools mcp sessions profile update version)' \\
|
||||
'1:command:(chat model gateway setup status cron doctor dump config skills tools mcp sessions profile update version)' \\
|
||||
'*::arg:->args'
|
||||
|
||||
case $words[1] in
|
||||
|
|
|
|||
|
|
@ -2167,6 +2167,71 @@ def _setup_whatsapp():
|
|||
print_info("or personal self-chat) and pair via QR code.")
|
||||
|
||||
|
||||
def _setup_bluebubbles():
    """Configure BlueBubbles iMessage gateway.

    Interactive wizard that collects the BlueBubbles server URL and
    password (both required), then optionally an allowlist, a home channel
    and a webhook listener port, saving each value with ``save_env_value``.

    The URL and password are written only after both have been entered, so
    aborting at the password prompt does not leave a server URL saved
    without credentials (the URL is what other code uses to decide that
    BlueBubbles is configured).
    """
    print_header("BlueBubbles (iMessage)")
    existing = get_env_value("BLUEBUBBLES_SERVER_URL")
    if existing:
        print_info("BlueBubbles: already configured")
        if not prompt_yes_no("Reconfigure BlueBubbles?", False):
            return

    print_info("Connects Hermes to iMessage via BlueBubbles — a free, open-source")
    print_info("macOS server that bridges iMessage to any device.")
    print_info(" Requires a Mac running BlueBubbles Server v1.0.0+")
    print_info(" Download: https://bluebubbles.app/")
    print()
    print_info("In BlueBubbles Server → Settings → API, note your Server URL and Password.")
    print()

    server_url = prompt("BlueBubbles server URL (e.g. http://192.168.1.10:1234)")
    if not server_url:
        print_warning("Server URL is required — skipping BlueBubbles setup")
        return

    password = prompt("BlueBubbles server password", password=True)
    if not password:
        print_warning("Password is required — skipping BlueBubbles setup")
        return

    # Persist the pair together, only once both required values exist.
    save_env_value("BLUEBUBBLES_SERVER_URL", server_url.rstrip("/"))
    save_env_value("BLUEBUBBLES_PASSWORD", password)
    print_success("BlueBubbles credentials saved")

    print()
    print_info("🔒 Security: Restrict who can message your bot")
    print_info(" Use iMessage addresses: email (user@icloud.com) or phone (+15551234567)")
    print()
    allowed_users = prompt("Allowed iMessage addresses (comma-separated, leave empty for open access)")
    if allowed_users:
        # Stored without spaces so the value is a clean comma-separated list.
        save_env_value("BLUEBUBBLES_ALLOWED_USERS", allowed_users.replace(" ", ""))
        print_success("BlueBubbles allowlist configured")
    else:
        print_info("⚠️ No allowlist set — anyone who can iMessage you can use the bot!")

    print()
    print_info("📬 Home Channel: phone or email for cron job delivery and notifications.")
    print_info(" You can also set this later with /set-home in your iMessage chat.")
    home_channel = prompt("Home channel address (leave empty to set later)")
    if home_channel:
        save_env_value("BLUEBUBBLES_HOME_CHANNEL", home_channel)

    print()
    print_info("Advanced settings (defaults are fine for most setups):")
    if prompt_yes_no("Configure webhook listener settings?", False):
        webhook_port = prompt("Webhook listener port (default: 8645)")
        if webhook_port:
            try:
                port = int(webhook_port)
                # Reject values that parse as integers but are not valid
                # TCP ports (0, negatives, > 65535).
                if not 1 <= port <= 65535:
                    raise ValueError(f"port out of range: {port}")
            except ValueError:
                print_warning("Invalid port number, using default 8645")
            else:
                save_env_value("BLUEBUBBLES_WEBHOOK_PORT", str(port))
                print_success(f"Webhook port set to {port}")

    print()
    print_info("Requires the BlueBubbles Private API helper for typing indicators,")
    print_info("read receipts, and tapback reactions. Basic messaging works without it.")
    print_info(" Install: https://docs.bluebubbles.app/helper-bundle/installation")
|
||||
|
||||
|
||||
def _setup_webhooks():
|
||||
"""Configure webhook integration."""
|
||||
print_header("Webhooks")
|
||||
|
|
@ -2221,6 +2286,7 @@ _GATEWAY_PLATFORMS = [
|
|||
("Matrix", "MATRIX_ACCESS_TOKEN", _setup_matrix),
|
||||
("Mattermost", "MATTERMOST_TOKEN", _setup_mattermost),
|
||||
("WhatsApp", "WHATSAPP_ENABLED", _setup_whatsapp),
|
||||
("BlueBubbles (iMessage)", "BLUEBUBBLES_SERVER_URL", _setup_bluebubbles),
|
||||
("Webhooks (GitHub, GitLab, etc.)", "WEBHOOK_ENABLED", _setup_webhooks),
|
||||
]
|
||||
|
||||
|
|
@ -2264,6 +2330,7 @@ def setup_gateway(config: dict):
|
|||
or get_env_value("MATRIX_ACCESS_TOKEN")
|
||||
or get_env_value("MATRIX_PASSWORD")
|
||||
or get_env_value("WHATSAPP_ENABLED")
|
||||
or get_env_value("BLUEBUBBLES_SERVER_URL")
|
||||
or get_env_value("WEBHOOK_ENABLED")
|
||||
)
|
||||
if any_messaging:
|
||||
|
|
@ -2283,6 +2350,8 @@ def setup_gateway(config: dict):
|
|||
missing_home.append("Discord")
|
||||
if get_env_value("SLACK_BOT_TOKEN") and not get_env_value("SLACK_HOME_CHANNEL"):
|
||||
missing_home.append("Slack")
|
||||
if get_env_value("BLUEBUBBLES_SERVER_URL") and not get_env_value("BLUEBUBBLES_HOME_CHANNEL"):
|
||||
missing_home.append("BlueBubbles")
|
||||
|
||||
if missing_home:
|
||||
print()
|
||||
|
|
@ -2453,6 +2522,8 @@ def _get_section_config_summary(config: dict, section_key: str) -> Optional[str]
|
|||
platforms.append("WhatsApp")
|
||||
if get_env_value("SIGNAL_ACCOUNT"):
|
||||
platforms.append("Signal")
|
||||
if get_env_value("BLUEBUBBLES_SERVER_URL"):
|
||||
platforms.append("BlueBubbles")
|
||||
if platforms:
|
||||
return ", ".join(platforms)
|
||||
return None # No platforms configured — section must run
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ PLATFORMS = {
|
|||
"slack": "💼 Slack",
|
||||
"whatsapp": "📱 WhatsApp",
|
||||
"signal": "📡 Signal",
|
||||
"bluebubbles": "💬 BlueBubbles",
|
||||
"email": "📧 Email",
|
||||
"homeassistant": "🏠 Home Assistant",
|
||||
"mattermost": "💬 Mattermost",
|
||||
|
|
|
|||
|
|
@ -302,6 +302,7 @@ def show_status(args):
|
|||
"DingTalk": ("DINGTALK_CLIENT_ID", None),
|
||||
"Feishu": ("FEISHU_APP_ID", "FEISHU_HOME_CHANNEL"),
|
||||
"WeCom": ("WECOM_BOT_ID", "WECOM_HOME_CHANNEL"),
|
||||
"BlueBubbles": ("BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_HOME_CHANNEL"),
|
||||
}
|
||||
|
||||
for name, (token_var, home_var) in platforms.items():
|
||||
|
|
|
|||
|
|
@ -126,6 +126,7 @@ PLATFORMS = {
|
|||
"slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"},
|
||||
"whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"},
|
||||
"signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"},
|
||||
"bluebubbles": {"label": "💙 BlueBubbles", "default_toolset": "hermes-bluebubbles"},
|
||||
"homeassistant": {"label": "🏠 Home Assistant", "default_toolset": "hermes-homeassistant"},
|
||||
"email": {"label": "📧 Email", "default_toolset": "hermes-email"},
|
||||
"matrix": {"label": "💬 Matrix", "default_toolset": "hermes-matrix"},
|
||||
|
|
|
|||
|
|
@ -1235,10 +1235,10 @@ class SessionDB:
|
|||
self._execute_write(_do)
|
||||
|
||||
def delete_session(self, session_id: str) -> bool:
|
||||
"""Delete a session, its child sessions, and all their messages.
|
||||
"""Delete a session and all its messages.
|
||||
|
||||
Child sessions (subagent runs, compression continuations) are deleted
|
||||
first to satisfy the ``parent_session_id`` foreign key constraint.
|
||||
Child sessions are orphaned (parent_session_id set to NULL) rather
|
||||
than cascade-deleted, so they remain accessible independently.
|
||||
Returns True if the session was found and deleted.
|
||||
"""
|
||||
def _do(conn):
|
||||
|
|
@ -1247,15 +1247,12 @@ class SessionDB:
|
|||
)
|
||||
if cursor.fetchone()[0] == 0:
|
||||
return False
|
||||
# Delete child sessions first (FK constraint)
|
||||
child_ids = [r[0] for r in conn.execute(
|
||||
"SELECT id FROM sessions WHERE parent_session_id = ?",
|
||||
# Orphan child sessions so FK constraint is satisfied
|
||||
conn.execute(
|
||||
"UPDATE sessions SET parent_session_id = NULL "
|
||||
"WHERE parent_session_id = ?",
|
||||
(session_id,),
|
||||
).fetchall()]
|
||||
for cid in child_ids:
|
||||
conn.execute("DELETE FROM messages WHERE session_id = ?", (cid,))
|
||||
conn.execute("DELETE FROM sessions WHERE id = ?", (cid,))
|
||||
# Delete the session itself
|
||||
)
|
||||
conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
|
||||
return True
|
||||
|
|
@ -1264,9 +1261,9 @@ class SessionDB:
|
|||
def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int:
|
||||
"""Delete sessions older than N days. Returns count of deleted sessions.
|
||||
|
||||
Only prunes ended sessions (not active ones). Child sessions whose
|
||||
parents are being pruned are deleted first to satisfy the
|
||||
``parent_session_id`` foreign key constraint.
|
||||
Only prunes ended sessions (not active ones). Child sessions outside
|
||||
the prune window are orphaned (parent_session_id set to NULL) rather
|
||||
than cascade-deleted.
|
||||
"""
|
||||
cutoff = time.time() - (older_than_days * 86400)
|
||||
|
||||
|
|
@ -1284,17 +1281,16 @@ class SessionDB:
|
|||
)
|
||||
session_ids = set(row["id"] for row in cursor.fetchall())
|
||||
|
||||
# Delete children first whose parents are in the prune set
|
||||
# (avoids FK constraint errors)
|
||||
for sid in list(session_ids):
|
||||
child_ids = [r[0] for r in conn.execute(
|
||||
"SELECT id FROM sessions WHERE parent_session_id = ?",
|
||||
(sid,),
|
||||
).fetchall()]
|
||||
for cid in child_ids:
|
||||
conn.execute("DELETE FROM messages WHERE session_id = ?", (cid,))
|
||||
conn.execute("DELETE FROM sessions WHERE id = ?", (cid,))
|
||||
session_ids.discard(cid) # don't double-delete
|
||||
if not session_ids:
|
||||
return 0
|
||||
|
||||
# Orphan any sessions whose parent is about to be deleted
|
||||
placeholders = ",".join("?" * len(session_ids))
|
||||
conn.execute(
|
||||
f"UPDATE sessions SET parent_session_id = NULL "
|
||||
f"WHERE parent_session_id IN ({placeholders})",
|
||||
list(session_ids),
|
||||
)
|
||||
|
||||
for sid in session_ids:
|
||||
conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
|
||||
|
|
|
|||
|
|
@ -569,7 +569,7 @@
|
|||
|
||||
# ── Activation: link config + auth + documents ────────────────────
|
||||
{
|
||||
system.activationScripts."hermes-agent-setup" = lib.stringAfter [ "users" "setupSecrets" ] ''
|
||||
system.activationScripts."hermes-agent-setup" = lib.stringAfter ([ "users" ] ++ lib.optional (config.system.activationScripts ? setupSecrets) "setupSecrets") ''
|
||||
# Ensure directories exist (activation runs before tmpfiles)
|
||||
mkdir -p ${cfg.stateDir}/.hermes
|
||||
mkdir -p ${cfg.stateDir}/home
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@
|
|||
};
|
||||
|
||||
runtimeDeps = with pkgs; [
|
||||
nodejs_20 ripgrep git openssh ffmpeg
|
||||
nodejs_20 ripgrep git openssh ffmpeg tirith
|
||||
];
|
||||
|
||||
runtimePath = pkgs.lib.makeBinPath runtimeDeps;
|
||||
|
|
|
|||
|
|
@ -1803,30 +1803,34 @@ class Migrator:
|
|||
def migrate_cron_jobs(self, config: Optional[Dict[str, Any]] = None) -> None:
|
||||
config = config or self.load_openclaw_config()
|
||||
cron = config.get("cron") or {}
|
||||
if not cron:
|
||||
self.record("cron-jobs", None, None, "skipped", "No cron configuration found")
|
||||
return
|
||||
|
||||
# Archive the full cron config
|
||||
if self.archive_dir and self.execute:
|
||||
self.archive_dir.mkdir(parents=True, exist_ok=True)
|
||||
dest = self.archive_dir / "cron-config.json"
|
||||
dest.write_text(json.dumps(cron, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
|
||||
self.record("cron-jobs", "openclaw.json cron.*", str(dest), "archived",
|
||||
"Cron config archived. Use 'hermes cron' to recreate jobs manually.")
|
||||
else:
|
||||
self.record("cron-jobs", "openclaw.json cron.*", "archive/cron-config.json",
|
||||
"archived", "Would archive cron config")
|
||||
|
||||
# Also check for cron store files
|
||||
cron_store = self.source_root / "cron"
|
||||
found_any = False
|
||||
|
||||
# Archive the full cron config when present
|
||||
if cron:
|
||||
found_any = True
|
||||
if self.archive_dir and self.execute:
|
||||
self.archive_dir.mkdir(parents=True, exist_ok=True)
|
||||
dest = self.archive_dir / "cron-config.json"
|
||||
dest.write_text(json.dumps(cron, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
|
||||
self.record("cron-jobs", "openclaw.json cron.*", str(dest), "archived",
|
||||
"Cron config archived. Use 'hermes cron' to recreate jobs manually.")
|
||||
else:
|
||||
self.record("cron-jobs", "openclaw.json cron.*", "archive/cron-config.json",
|
||||
"archived", "Would archive cron config")
|
||||
|
||||
# Also check for cron store files even when config.cron is missing
|
||||
if cron_store.is_dir() and self.archive_dir:
|
||||
found_any = True
|
||||
dest_cron = self.archive_dir / "cron-store"
|
||||
if self.execute:
|
||||
shutil.copytree(cron_store, dest_cron, dirs_exist_ok=True)
|
||||
self.record("cron-jobs", str(cron_store), str(dest_cron), "archived",
|
||||
"Cron job store archived")
|
||||
|
||||
if not found_any:
|
||||
self.record("cron-jobs", None, None, "skipped", "No cron configuration found")
|
||||
|
||||
# ── Hooks ─────────────────────────────────────────────────
|
||||
def migrate_hooks_config(self, config: Optional[Dict[str, Any]] = None) -> None:
|
||||
config = config or self.load_openclaw_config()
|
||||
|
|
@ -2454,6 +2458,15 @@ class Migrator:
|
|||
notes.append(f"- **{item.kind}**: {item.reason}")
|
||||
notes.append("")
|
||||
|
||||
has_cron_config_archive = any(
|
||||
i.kind == "cron-jobs" and i.status == "archived" and i.destination and i.destination.endswith("cron-config.json")
|
||||
for i in self.items
|
||||
)
|
||||
has_cron_store_archive = any(
|
||||
i.kind == "cron-jobs" and i.status == "archived" and i.destination and i.destination.endswith("cron-store")
|
||||
for i in self.items
|
||||
)
|
||||
|
||||
notes.extend([
|
||||
"## IMPORTANT: Archive the OpenClaw Directory",
|
||||
"",
|
||||
|
|
@ -2475,7 +2488,14 @@ class Migrator:
|
|||
"- Run `hermes claw cleanup` to archive the OpenClaw directory (prevents state confusion)",
|
||||
"- Run `hermes setup` to configure any remaining settings",
|
||||
"- Run `hermes mcp list` to verify MCP servers were imported correctly",
|
||||
"- Run `hermes cron` to recreate scheduled tasks (see archive/cron-config.json)",
|
||||
])
|
||||
|
||||
if has_cron_config_archive:
|
||||
notes.append("- Run `hermes cron` to recreate scheduled tasks (see archive/cron-config.json)")
|
||||
elif has_cron_store_archive:
|
||||
notes.append("- Run `hermes cron` to recreate scheduled tasks (see archived cron-store)")
|
||||
|
||||
notes.extend([
|
||||
"- Run `hermes gateway install` if you need the gateway service",
|
||||
"- Review `~/.hermes/config.yaml` for any adjustments",
|
||||
"",
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
# Hindsight Memory Provider
|
||||
|
||||
Long-term memory with knowledge graph, entity resolution, and multi-strategy retrieval. Supports cloud and local (embedded) modes.
|
||||
Long-term memory with knowledge graph, entity resolution, and multi-strategy retrieval. Supports cloud, local embedded, and local external modes.
|
||||
|
||||
## Requirements
|
||||
|
||||
- **Cloud:** API key from [ui.hindsight.vectorize.io](https://ui.hindsight.vectorize.io)
|
||||
- **Local:** API key for a supported LLM provider (OpenAI, Anthropic, Gemini, Groq, MiniMax, or Ollama). Embeddings and reranking run locally — no additional API keys needed.
|
||||
- **Local Embedded:** API key for a supported LLM provider (OpenAI, Anthropic, Gemini, Groq, OpenRouter, MiniMax, Ollama, or any OpenAI-compatible endpoint). Embeddings and reranking run locally — no additional API keys needed.
|
||||
- **Local External:** A running Hindsight instance (Docker or self-hosted) reachable over HTTP.
|
||||
|
||||
## Setup
|
||||
|
||||
|
|
@ -21,17 +22,28 @@ hermes config set memory.provider hindsight
|
|||
echo "HINDSIGHT_API_KEY=your-key" >> ~/.hermes/.env
|
||||
```
|
||||
|
||||
### Cloud Mode
|
||||
### Cloud
|
||||
|
||||
Connects to the Hindsight Cloud API. Requires an API key from [ui.hindsight.vectorize.io](https://ui.hindsight.vectorize.io).
|
||||
|
||||
### Local Mode
|
||||
### Local Embedded
|
||||
|
||||
Runs an embedded Hindsight server with built-in PostgreSQL. Requires an LLM API key (e.g. Groq, OpenAI, Anthropic) for memory extraction and synthesis. The daemon starts automatically in the background on first use and stops after 5 minutes of inactivity.
|
||||
Hermes spins up a local Hindsight daemon with built-in PostgreSQL. Requires an LLM API key for memory extraction and synthesis. The daemon starts automatically in the background on first use and stops after 5 minutes of inactivity.
|
||||
|
||||
Supports any OpenAI-compatible LLM endpoint (llama.cpp, vLLM, LM Studio, etc.) — pick `openai_compatible` as the provider and enter the base URL.
|
||||
|
||||
Daemon startup logs: `~/.hermes/logs/hindsight-embed.log`
|
||||
Daemon runtime logs: `~/.hindsight/profiles/<profile>.log`
|
||||
|
||||
To open the Hindsight web UI (local embedded mode only):
|
||||
```bash
|
||||
hindsight-embed -p hermes ui start
|
||||
```
|
||||
|
||||
### Local External
|
||||
|
||||
Points the plugin at an existing Hindsight instance you're already running (Docker, self-hosted, etc.). No daemon management — just a URL and an optional API key.
|
||||
|
||||
## Config
|
||||
|
||||
Config file: `~/.hermes/hindsight/config.json`
|
||||
|
|
@ -40,40 +52,58 @@ Config file: `~/.hermes/hindsight/config.json`
|
|||
|
||||
| Key | Default | Description |
|
||||
|-----|---------|-------------|
|
||||
| `mode` | `cloud` | `cloud` or `local` |
|
||||
| `api_url` | `https://api.hindsight.vectorize.io` | API URL (cloud mode) |
|
||||
| `api_url` | `http://localhost:8888` | API URL (local mode, unused — daemon manages its own port) |
|
||||
| `mode` | `cloud` | `cloud`, `local_embedded`, or `local_external` |
|
||||
| `api_url` | `https://api.hindsight.vectorize.io` | API URL (cloud and local_external modes; defaults to `http://localhost:8888` in local_external mode) |
|
||||
|
||||
### Memory
|
||||
### Memory Bank
|
||||
|
||||
| Key | Default | Description |
|
||||
|-----|---------|-------------|
|
||||
| `bank_id` | `hermes` | Memory bank name |
|
||||
| `budget` | `mid` | Recall thoroughness: `low` / `mid` / `high` |
|
||||
| `bank_mission` | — | Reflect mission (identity/framing for reflect reasoning). Applied via Banks API. |
|
||||
| `bank_retain_mission` | — | Retain mission (steers what gets extracted). Applied via Banks API. |
|
||||
|
||||
### Recall
|
||||
|
||||
| Key | Default | Description |
|
||||
|-----|---------|-------------|
|
||||
| `recall_budget` | `mid` | Recall thoroughness: `low` / `mid` / `high` |
|
||||
| `recall_prefetch_method` | `recall` | Auto-recall method: `recall` (raw facts) or `reflect` (LLM synthesis) |
|
||||
| `recall_max_tokens` | `4096` | Maximum tokens for recall results |
|
||||
| `recall_max_input_chars` | `800` | Maximum input query length for auto-recall |
|
||||
| `recall_prompt_preamble` | — | Custom preamble for recalled memories in context |
|
||||
| `recall_tags` | — | Tags to filter when searching memories |
|
||||
| `recall_tags_match` | `any` | Tag matching mode: `any` / `all` / `any_strict` / `all_strict` |
|
||||
| `auto_recall` | `true` | Automatically recall memories before each turn |
|
||||
|
||||
### Retain
|
||||
|
||||
| Key | Default | Description |
|
||||
|-----|---------|-------------|
|
||||
| `auto_retain` | `true` | Automatically retain conversation turns |
|
||||
| `retain_async` | `true` | Process retain asynchronously on the Hindsight server |
|
||||
| `retain_every_n_turns` | `1` | Retain every N turns (1 = every turn) |
|
||||
| `retain_context` | `conversation between Hermes Agent and the User` | Context label for retained memories |
|
||||
| `tags` | — | Tags applied when storing memories |
|
||||
|
||||
### Integration
|
||||
|
||||
| Key | Default | Description |
|
||||
|-----|---------|-------------|
|
||||
| `memory_mode` | `hybrid` | How memories are integrated into the agent |
|
||||
| `prefetch_method` | `recall` | Method for automatic context injection |
|
||||
|
||||
**memory_mode:**
|
||||
- `hybrid` — automatic context injection + tools available to the LLM
|
||||
- `context` — automatic injection only, no tools exposed
|
||||
- `tools` — tools only, no automatic injection
|
||||
|
||||
**prefetch_method:**
|
||||
- `recall` — injects raw memory facts (fast)
|
||||
- `reflect` — injects LLM-synthesized summary (slower, more coherent)
|
||||
|
||||
### Local Mode LLM
|
||||
### Local Embedded LLM
|
||||
|
||||
| Key | Default | Description |
|
||||
|-----|---------|-------------|
|
||||
| `llm_provider` | `openai` | LLM provider: `openai`, `anthropic`, `gemini`, `groq`, `minimax`, `ollama` |
|
||||
| `llm_model` | per-provider | Model name (e.g. `gpt-4o-mini`, `openai/gpt-oss-120b`) |
|
||||
| `llm_base_url` | — | LLM Base URL override (e.g. `https://openrouter.ai/api/v1`) |
|
||||
| `llm_provider` | `openai` | `openai`, `anthropic`, `gemini`, `groq`, `openrouter`, `minimax`, `ollama`, `lmstudio`, `openai_compatible` |
|
||||
| `llm_model` | per-provider | Model name (e.g. `gpt-4o-mini`, `qwen/qwen3.5-9b`) |
|
||||
| `llm_base_url` | — | Endpoint URL for `openai_compatible` (e.g. `http://192.168.1.10:8080/v1`) |
|
||||
|
||||
The LLM API key is stored in `~/.hermes/.env` as `HINDSIGHT_LLM_API_KEY`.
|
||||
|
||||
|
|
@ -97,4 +127,8 @@ Available in `hybrid` and `tools` memory modes:
|
|||
| `HINDSIGHT_API_URL` | Override API endpoint |
|
||||
| `HINDSIGHT_BANK_ID` | Override bank name |
|
||||
| `HINDSIGHT_BUDGET` | Override recall budget |
|
||||
| `HINDSIGHT_MODE` | Override mode (`cloud` / `local`) |
|
||||
| `HINDSIGHT_MODE` | Override mode (`cloud`, `local_embedded`, `local_external`) |
|
||||
|
||||
## Client Version
|
||||
|
||||
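Requires `hindsight-client >= 0.4.22`. On session start the plugin attempts to upgrade an older client automatically using `uv`; if `uv` is not installed or the upgrade fails, it logs a warning with the manual install command and continues.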
|
||||
|
|
|
|||
|
|
@ -28,21 +28,25 @@ from hermes_constants import get_hermes_home
|
|||
from typing import Any, Dict, List
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
from hermes_constants import get_hermes_home
|
||||
from tools.registry import tool_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_API_URL = "https://api.hindsight.vectorize.io"
|
||||
_DEFAULT_LOCAL_URL = "http://localhost:8888"
|
||||
_MIN_CLIENT_VERSION = "0.4.22"
|
||||
_VALID_BUDGETS = {"low", "mid", "high"}
|
||||
_PROVIDER_DEFAULT_MODELS = {
|
||||
"openai": "gpt-4o-mini",
|
||||
"anthropic": "claude-haiku-4-5",
|
||||
"gemini": "gemini-2.5-flash",
|
||||
"groq": "openai/gpt-oss-120b",
|
||||
"openrouter": "qwen/qwen3.5-9b",
|
||||
"minimax": "MiniMax-M2.7",
|
||||
"ollama": "gemma3:12b",
|
||||
"lmstudio": "local-model",
|
||||
"openai_compatible": "your-model-name",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -188,6 +192,7 @@ class HindsightMemoryProvider(MemoryProvider):
|
|||
self._bank_id = "hermes"
|
||||
self._budget = "mid"
|
||||
self._mode = "cloud"
|
||||
self._llm_base_url = ""
|
||||
self._memory_mode = "hybrid" # "context", "tools", or "hybrid"
|
||||
self._prefetch_method = "recall" # "recall" or "reflect"
|
||||
self._client = None
|
||||
|
|
@ -195,6 +200,31 @@ class HindsightMemoryProvider(MemoryProvider):
|
|||
self._prefetch_lock = threading.Lock()
|
||||
self._prefetch_thread = None
|
||||
self._sync_thread = None
|
||||
self._session_id = ""
|
||||
|
||||
# Tags
|
||||
self._tags: list[str] | None = None
|
||||
self._recall_tags: list[str] | None = None
|
||||
self._recall_tags_match = "any"
|
||||
|
||||
# Retain controls
|
||||
self._auto_retain = True
|
||||
self._retain_every_n_turns = 1
|
||||
self._retain_context = "conversation between Hermes Agent and the User"
|
||||
self._turn_counter = 0
|
||||
self._session_turns: list[str] = [] # accumulates ALL turns for the session
|
||||
|
||||
# Recall controls
|
||||
self._auto_recall = True
|
||||
self._recall_max_tokens = 4096
|
||||
self._recall_types: list[str] | None = None
|
||||
self._recall_prompt_preamble = ""
|
||||
self._recall_max_input_chars = 800
|
||||
|
||||
# Bank
|
||||
self._bank_mission = ""
|
||||
self._bank_retain_mission: str | None = None
|
||||
self._retain_async = True
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
|
@ -204,7 +234,7 @@ class HindsightMemoryProvider(MemoryProvider):
|
|||
try:
|
||||
cfg = _load_config()
|
||||
mode = cfg.get("mode", "cloud")
|
||||
if mode == "local":
|
||||
if mode in ("local", "local_embedded", "local_external"):
|
||||
return True
|
||||
has_key = bool(cfg.get("apiKey") or os.environ.get("HINDSIGHT_API_KEY", ""))
|
||||
has_url = bool(cfg.get("api_url") or os.environ.get("HINDSIGHT_API_URL", ""))
|
||||
|
|
@ -228,73 +258,306 @@ class HindsightMemoryProvider(MemoryProvider):
|
|||
existing.update(values)
|
||||
config_path.write_text(json.dumps(existing, indent=2))
|
||||
|
||||
def post_setup(self, hermes_home: str, config: dict) -> None:
|
||||
"""Custom setup wizard — installs only the deps needed for the selected mode."""
|
||||
import getpass
|
||||
import subprocess
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_cli.config import save_config
|
||||
|
||||
from hermes_cli.memory_setup import _curses_select
|
||||
|
||||
print("\n Configuring Hindsight memory:\n")
|
||||
|
||||
# Step 1: Mode selection
|
||||
mode_items = [
|
||||
("Cloud", "Hindsight Cloud API (lightweight, just needs an API key)"),
|
||||
("Local Embedded", "Run Hindsight locally (downloads ~200MB, needs LLM key)"),
|
||||
("Local External", "Connect to an existing Hindsight instance"),
|
||||
]
|
||||
mode_idx = _curses_select(" Select mode", mode_items, default=0)
|
||||
mode = ["cloud", "local_embedded", "local_external"][mode_idx]
|
||||
|
||||
provider_config: dict = {"mode": mode}
|
||||
env_writes: dict = {}
|
||||
|
||||
# Step 2: Install/upgrade deps for selected mode
|
||||
_MIN_CLIENT_VERSION = "0.4.22"
|
||||
cloud_dep = f"hindsight-client>={_MIN_CLIENT_VERSION}"
|
||||
local_dep = "hindsight-all"
|
||||
if mode == "local_embedded":
|
||||
deps_to_install = [local_dep]
|
||||
elif mode == "local_external":
|
||||
deps_to_install = [cloud_dep]
|
||||
else:
|
||||
deps_to_install = [cloud_dep]
|
||||
|
||||
print(f"\n Checking dependencies...")
|
||||
uv_path = shutil.which("uv")
|
||||
if not uv_path:
|
||||
print(" ⚠ uv not found — install it: curl -LsSf https://astral.sh/uv/install.sh | sh")
|
||||
print(f" Then run manually: uv pip install --python {sys.executable} {' '.join(deps_to_install)}")
|
||||
else:
|
||||
try:
|
||||
subprocess.run(
|
||||
[uv_path, "pip", "install", "--python", sys.executable, "--quiet", "--upgrade"] + deps_to_install,
|
||||
check=True, timeout=120, capture_output=True,
|
||||
)
|
||||
print(f" ✓ Dependencies up to date")
|
||||
except Exception as e:
|
||||
print(f" ⚠ Install failed: {e}")
|
||||
print(f" Run manually: uv pip install --python {sys.executable} {' '.join(deps_to_install)}")
|
||||
|
||||
# Step 3: Mode-specific config
|
||||
if mode == "cloud":
|
||||
print(f"\n Get your API key at https://ui.hindsight.vectorize.io\n")
|
||||
existing_key = os.environ.get("HINDSIGHT_API_KEY", "")
|
||||
if existing_key:
|
||||
masked = f"...{existing_key[-4:]}" if len(existing_key) > 4 else "set"
|
||||
sys.stdout.write(f" API key (current: {masked}, blank to keep): ")
|
||||
sys.stdout.flush()
|
||||
api_key = getpass.getpass(prompt="") if sys.stdin.isatty() else sys.stdin.readline().strip()
|
||||
else:
|
||||
sys.stdout.write(" API key: ")
|
||||
sys.stdout.flush()
|
||||
api_key = getpass.getpass(prompt="") if sys.stdin.isatty() else sys.stdin.readline().strip()
|
||||
if api_key:
|
||||
env_writes["HINDSIGHT_API_KEY"] = api_key
|
||||
|
||||
val = input(f" API URL [{_DEFAULT_API_URL}]: ").strip()
|
||||
if val:
|
||||
provider_config["api_url"] = val
|
||||
|
||||
elif mode == "local_external":
|
||||
val = input(f" Hindsight API URL [{_DEFAULT_LOCAL_URL}]: ").strip()
|
||||
provider_config["api_url"] = val or _DEFAULT_LOCAL_URL
|
||||
|
||||
sys.stdout.write(" API key (optional, blank to skip): ")
|
||||
sys.stdout.flush()
|
||||
api_key = getpass.getpass(prompt="") if sys.stdin.isatty() else sys.stdin.readline().strip()
|
||||
if api_key:
|
||||
env_writes["HINDSIGHT_API_KEY"] = api_key
|
||||
|
||||
else: # local_embedded
|
||||
providers_list = list(_PROVIDER_DEFAULT_MODELS.keys())
|
||||
llm_items = [
|
||||
(p, f"default model: {_PROVIDER_DEFAULT_MODELS[p]}")
|
||||
for p in providers_list
|
||||
]
|
||||
llm_idx = _curses_select(" Select LLM provider", llm_items, default=0)
|
||||
llm_provider = providers_list[llm_idx]
|
||||
|
||||
provider_config["llm_provider"] = llm_provider
|
||||
|
||||
if llm_provider == "openai_compatible":
|
||||
val = input(" LLM endpoint URL (e.g. http://192.168.1.10:8080/v1): ").strip()
|
||||
if val:
|
||||
provider_config["llm_base_url"] = val
|
||||
elif llm_provider == "openrouter":
|
||||
provider_config["llm_base_url"] = "https://openrouter.ai/api/v1"
|
||||
|
||||
default_model = _PROVIDER_DEFAULT_MODELS.get(llm_provider, "gpt-4o-mini")
|
||||
val = input(f" LLM model [{default_model}]: ").strip()
|
||||
provider_config["llm_model"] = val or default_model
|
||||
|
||||
sys.stdout.write(" LLM API key: ")
|
||||
sys.stdout.flush()
|
||||
llm_key = getpass.getpass(prompt="") if sys.stdin.isatty() else sys.stdin.readline().strip()
|
||||
if llm_key:
|
||||
env_writes["HINDSIGHT_LLM_API_KEY"] = llm_key
|
||||
|
||||
# Step 4: Save everything
|
||||
provider_config["bank_id"] = "hermes"
|
||||
provider_config["recall_budget"] = "mid"
|
||||
bank_id = "hermes"
|
||||
config["memory"]["provider"] = "hindsight"
|
||||
save_config(config)
|
||||
|
||||
self.save_config(provider_config, hermes_home)
|
||||
|
||||
if env_writes:
|
||||
env_path = Path(hermes_home) / ".env"
|
||||
env_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
existing_lines = []
|
||||
if env_path.exists():
|
||||
existing_lines = env_path.read_text().splitlines()
|
||||
updated_keys = set()
|
||||
new_lines = []
|
||||
for line in existing_lines:
|
||||
key_match = line.split("=", 1)[0].strip() if "=" in line and not line.startswith("#") else None
|
||||
if key_match and key_match in env_writes:
|
||||
new_lines.append(f"{key_match}={env_writes[key_match]}")
|
||||
updated_keys.add(key_match)
|
||||
else:
|
||||
new_lines.append(line)
|
||||
for k, v in env_writes.items():
|
||||
if k not in updated_keys:
|
||||
new_lines.append(f"{k}={v}")
|
||||
env_path.write_text("\n".join(new_lines) + "\n")
|
||||
|
||||
print(f"\n ✓ Hindsight memory configured ({mode} mode)")
|
||||
if env_writes:
|
||||
print(f" API keys saved to .env")
|
||||
print(f"\n Start a new session to activate.\n")
|
||||
|
||||
def get_config_schema(self):
|
||||
return [
|
||||
{"key": "mode", "description": "Cloud API or local embedded mode", "default": "cloud", "choices": ["cloud", "local"]},
|
||||
{"key": "api_url", "description": "Hindsight API URL", "default": _DEFAULT_API_URL, "when": {"mode": "cloud"}},
|
||||
{"key": "mode", "description": "Connection mode", "default": "cloud", "choices": ["cloud", "local_embedded", "local_external"]},
|
||||
# Cloud mode
|
||||
{"key": "api_url", "description": "Hindsight Cloud API URL", "default": _DEFAULT_API_URL, "when": {"mode": "cloud"}},
|
||||
{"key": "api_key", "description": "Hindsight Cloud API key", "secret": True, "env_var": "HINDSIGHT_API_KEY", "url": "https://ui.hindsight.vectorize.io", "when": {"mode": "cloud"}},
|
||||
{"key": "llm_provider", "description": "LLM provider for local mode", "default": "openai", "choices": ["openai", "anthropic", "gemini", "groq", "minimax", "ollama"], "when": {"mode": "local"}},
|
||||
{"key": "llm_api_key", "description": "LLM API key for local Hindsight", "secret": True, "env_var": "HINDSIGHT_LLM_API_KEY", "when": {"mode": "local"}},
|
||||
{"key": "llm_base_url", "description": "LLM Base URL (e.g. for OpenRouter)", "default": "", "env_var": "HINDSIGHT_API_LLM_BASE_URL", "when": {"mode": "local"}},
|
||||
{"key": "llm_model", "description": "LLM model for local mode", "default": "gpt-4o-mini", "default_from": {"field": "llm_provider", "map": _PROVIDER_DEFAULT_MODELS}, "when": {"mode": "local"}},
|
||||
# Local external mode
|
||||
{"key": "api_url", "description": "Hindsight API URL", "default": _DEFAULT_LOCAL_URL, "when": {"mode": "local_external"}},
|
||||
{"key": "api_key", "description": "API key (optional)", "secret": True, "env_var": "HINDSIGHT_API_KEY", "when": {"mode": "local_external"}},
|
||||
# Local embedded mode
|
||||
{"key": "llm_provider", "description": "LLM provider", "default": "openai", "choices": ["openai", "anthropic", "gemini", "groq", "openrouter", "minimax", "ollama", "lmstudio", "openai_compatible"], "when": {"mode": "local_embedded"}},
|
||||
{"key": "llm_base_url", "description": "Endpoint URL (e.g. http://192.168.1.10:8080/v1)", "default": "", "when": {"mode": "local_embedded", "llm_provider": "openai_compatible"}},
|
||||
{"key": "llm_api_key", "description": "LLM API key (optional for openai_compatible)", "secret": True, "env_var": "HINDSIGHT_LLM_API_KEY", "when": {"mode": "local_embedded"}},
|
||||
{"key": "llm_model", "description": "LLM model", "default": "gpt-4o-mini", "default_from": {"field": "llm_provider", "map": _PROVIDER_DEFAULT_MODELS}, "when": {"mode": "local_embedded"}},
|
||||
{"key": "bank_id", "description": "Memory bank name", "default": "hermes"},
|
||||
{"key": "budget", "description": "Recall thoroughness", "default": "mid", "choices": ["low", "mid", "high"]},
|
||||
{"key": "bank_mission", "description": "Mission/purpose description for the memory bank"},
|
||||
{"key": "bank_retain_mission", "description": "Custom extraction prompt for memory retention"},
|
||||
{"key": "recall_budget", "description": "Recall thoroughness", "default": "mid", "choices": ["low", "mid", "high"]},
|
||||
{"key": "memory_mode", "description": "Memory integration mode", "default": "hybrid", "choices": ["hybrid", "context", "tools"]},
|
||||
{"key": "prefetch_method", "description": "Auto-recall method", "default": "recall", "choices": ["recall", "reflect"]},
|
||||
{"key": "recall_prefetch_method", "description": "Auto-recall method", "default": "recall", "choices": ["recall", "reflect"]},
|
||||
{"key": "tags", "description": "Tags applied when storing memories (comma-separated)", "default": ""},
|
||||
{"key": "recall_tags", "description": "Tags to filter when searching memories (comma-separated)", "default": ""},
|
||||
{"key": "recall_tags_match", "description": "Tag matching mode for recall", "default": "any", "choices": ["any", "all", "any_strict", "all_strict"]},
|
||||
{"key": "auto_recall", "description": "Automatically recall memories before each turn", "default": True},
|
||||
{"key": "auto_retain", "description": "Automatically retain conversation turns", "default": True},
|
||||
{"key": "retain_every_n_turns", "description": "Retain every N turns (1 = every turn)", "default": 1},
|
||||
{"key": "retain_async","description": "Process retain asynchronously on the Hindsight server", "default": True},
|
||||
{"key": "retain_context", "description": "Context label for retained memories", "default": "conversation between Hermes Agent and the User"},
|
||||
{"key": "recall_max_tokens", "description": "Maximum tokens for recall results", "default": 4096},
|
||||
{"key": "recall_max_input_chars", "description": "Maximum input query length for auto-recall", "default": 800},
|
||||
{"key": "recall_prompt_preamble", "description": "Custom preamble for recalled memories in context"},
|
||||
]
|
||||
|
||||
def _get_client(self):
|
||||
"""Return the cached Hindsight client (created once, reused)."""
|
||||
if self._client is None:
|
||||
if self._mode == "local":
|
||||
if self._mode == "local_embedded":
|
||||
from hindsight import HindsightEmbedded
|
||||
# Disable __del__ on the class to prevent "attached to a
|
||||
# different loop" errors during GC — we handle cleanup in
|
||||
# shutdown() instead.
|
||||
HindsightEmbedded.__del__ = lambda self: None
|
||||
llm_provider = self._config.get("llm_provider", "")
|
||||
if llm_provider in ("openai_compatible", "openrouter"):
|
||||
llm_provider = "openai"
|
||||
logger.debug("Creating HindsightEmbedded client (profile=%s, provider=%s)",
|
||||
self._config.get("profile", "hermes"), llm_provider)
|
||||
kwargs = dict(
|
||||
profile=self._config.get("profile", "hermes"),
|
||||
llm_provider=self._config.get("llm_provider", ""),
|
||||
llm_api_key=self._config.get("llm_api_key") or os.environ.get("HINDSIGHT_LLM_API_KEY", ""),
|
||||
llm_provider=llm_provider,
|
||||
llm_api_key=self._config.get("llmApiKey") or self._config.get("llm_api_key") or os.environ.get("HINDSIGHT_LLM_API_KEY", ""),
|
||||
llm_model=self._config.get("llm_model", ""),
|
||||
)
|
||||
base_url = self._config.get("llm_base_url") or os.environ.get("HINDSIGHT_API_LLM_BASE_URL", "")
|
||||
if base_url:
|
||||
kwargs["llm_base_url"] = base_url
|
||||
if self._llm_base_url:
|
||||
kwargs["llm_base_url"] = self._llm_base_url
|
||||
self._client = HindsightEmbedded(**kwargs)
|
||||
else:
|
||||
from hindsight_client import Hindsight
|
||||
kwargs = {"base_url": self._api_url, "timeout": 30.0}
|
||||
if self._api_key:
|
||||
kwargs["api_key"] = self._api_key
|
||||
logger.debug("Creating Hindsight cloud client (url=%s, has_key=%s)",
|
||||
self._api_url, bool(self._api_key))
|
||||
self._client = Hindsight(**kwargs)
|
||||
return self._client
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
|
||||
self._session_id = session_id
|
||||
|
||||
# Check client version and auto-upgrade if needed
|
||||
try:
|
||||
from importlib.metadata import version as pkg_version
|
||||
from packaging.version import Version
|
||||
installed = pkg_version("hindsight-client")
|
||||
if Version(installed) < Version(_MIN_CLIENT_VERSION):
|
||||
logger.warning("hindsight-client %s is outdated (need >=%s), attempting upgrade...",
|
||||
installed, _MIN_CLIENT_VERSION)
|
||||
import shutil, subprocess, sys
|
||||
uv_path = shutil.which("uv")
|
||||
if uv_path:
|
||||
try:
|
||||
subprocess.run(
|
||||
[uv_path, "pip", "install", "--python", sys.executable,
|
||||
"--quiet", "--upgrade", f"hindsight-client>={_MIN_CLIENT_VERSION}"],
|
||||
check=True, timeout=120, capture_output=True,
|
||||
)
|
||||
logger.info("hindsight-client upgraded to >=%s", _MIN_CLIENT_VERSION)
|
||||
except Exception as e:
|
||||
logger.warning("Auto-upgrade failed: %s. Run: uv pip install 'hindsight-client>=%s'",
|
||||
e, _MIN_CLIENT_VERSION)
|
||||
else:
|
||||
logger.warning("uv not found. Run: pip install 'hindsight-client>=%s'", _MIN_CLIENT_VERSION)
|
||||
except Exception:
|
||||
pass # packaging not available or other issue — proceed anyway
|
||||
|
||||
self._config = _load_config()
|
||||
self._mode = self._config.get("mode", "cloud")
|
||||
self._api_key = self._config.get("apiKey") or os.environ.get("HINDSIGHT_API_KEY", "")
|
||||
default_url = _DEFAULT_LOCAL_URL if self._mode == "local" else _DEFAULT_API_URL
|
||||
# "local" is a legacy alias for "local_embedded"
|
||||
if self._mode == "local":
|
||||
self._mode = "local_embedded"
|
||||
self._api_key = self._config.get("apiKey") or self._config.get("api_key") or os.environ.get("HINDSIGHT_API_KEY", "")
|
||||
default_url = _DEFAULT_LOCAL_URL if self._mode in ("local_embedded", "local_external") else _DEFAULT_API_URL
|
||||
self._api_url = self._config.get("api_url") or os.environ.get("HINDSIGHT_API_URL", default_url)
|
||||
self._llm_base_url = self._config.get("llm_base_url", "")
|
||||
|
||||
banks = self._config.get("banks", {}).get("hermes", {})
|
||||
self._bank_id = self._config.get("bank_id") or banks.get("bankId", "hermes")
|
||||
budget = self._config.get("budget") or banks.get("budget", "mid")
|
||||
budget = self._config.get("recall_budget") or self._config.get("budget") or banks.get("budget", "mid")
|
||||
self._budget = budget if budget in _VALID_BUDGETS else "mid"
|
||||
|
||||
memory_mode = self._config.get("memory_mode", "hybrid")
|
||||
self._memory_mode = memory_mode if memory_mode in ("context", "tools", "hybrid") else "hybrid"
|
||||
|
||||
prefetch_method = self._config.get("prefetch_method", "recall")
|
||||
prefetch_method = self._config.get("recall_prefetch_method", "recall")
|
||||
self._prefetch_method = prefetch_method if prefetch_method in ("recall", "reflect") else "recall"
|
||||
|
||||
logger.info("Hindsight initialized: mode=%s, api_url=%s, bank=%s, budget=%s, memory_mode=%s, prefetch_method=%s",
|
||||
self._mode, self._api_url, self._bank_id, self._budget, self._memory_mode, self._prefetch_method)
|
||||
# Bank options
|
||||
self._bank_mission = self._config.get("bank_mission", "")
|
||||
self._bank_retain_mission = self._config.get("bank_retain_mission") or None
|
||||
|
||||
# Tags
|
||||
self._tags = self._config.get("tags") or None
|
||||
self._recall_tags = self._config.get("recall_tags") or None
|
||||
self._recall_tags_match = self._config.get("recall_tags_match", "any")
|
||||
|
||||
# Retain controls
|
||||
self._auto_retain = self._config.get("auto_retain", True)
|
||||
self._retain_every_n_turns = max(1, int(self._config.get("retain_every_n_turns", 1)))
|
||||
self._retain_context = self._config.get("retain_context", "conversation between Hermes Agent and the User")
|
||||
|
||||
# Recall controls
|
||||
self._auto_recall = self._config.get("auto_recall", True)
|
||||
self._recall_max_tokens = int(self._config.get("recall_max_tokens", 4096))
|
||||
self._recall_types = self._config.get("recall_types") or None
|
||||
self._recall_prompt_preamble = self._config.get("recall_prompt_preamble", "")
|
||||
self._recall_max_input_chars = int(self._config.get("recall_max_input_chars", 800))
|
||||
self._retain_async = self._config.get("retain_async", True)
|
||||
|
||||
_client_version = "unknown"
|
||||
try:
|
||||
from importlib.metadata import version as pkg_version
|
||||
_client_version = pkg_version("hindsight-client")
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("Hindsight initialized: mode=%s, api_url=%s, bank=%s, budget=%s, memory_mode=%s, prefetch_method=%s, client=%s",
|
||||
self._mode, self._api_url, self._bank_id, self._budget, self._memory_mode, self._prefetch_method, _client_version)
|
||||
logger.debug("Hindsight config: auto_retain=%s, auto_recall=%s, retain_every_n=%d, "
|
||||
"retain_async=%s, retain_context=%s, "
|
||||
"recall_max_tokens=%d, recall_max_input_chars=%d, tags=%s, recall_tags=%s",
|
||||
self._auto_retain, self._auto_recall, self._retain_every_n_turns,
|
||||
self._retain_async, self._retain_context,
|
||||
self._recall_max_tokens, self._recall_max_input_chars,
|
||||
self._tags, self._recall_tags)
|
||||
|
||||
# For local embedded mode, start the embedded daemon in the background so it
|
||||
# doesn't block the chat. Redirect stdout/stderr to a log file to
|
||||
# prevent rich startup output from spamming the terminal.
|
||||
if self._mode == "local":
|
||||
if self._mode == "local_embedded":
|
||||
def _start_daemon():
|
||||
import traceback
|
||||
log_dir = get_hermes_home() / "logs"
|
||||
|
|
@ -320,6 +583,8 @@ class HindsightMemoryProvider(MemoryProvider):
|
|||
current_provider = self._config.get("llm_provider", "")
|
||||
current_model = self._config.get("llm_model", "")
|
||||
current_base_url = self._config.get("llm_base_url") or os.environ.get("HINDSIGHT_API_LLM_BASE_URL", "")
|
||||
# Map openai_compatible/openrouter → openai for the daemon (OpenAI wire format)
|
||||
daemon_provider = "openai" if current_provider in ("openai_compatible", "openrouter") else current_provider
|
||||
|
||||
# Read saved profile config
|
||||
saved = {}
|
||||
|
|
@ -330,7 +595,7 @@ class HindsightMemoryProvider(MemoryProvider):
|
|||
saved[k.strip()] = v.strip()
|
||||
|
||||
config_changed = (
|
||||
saved.get("HINDSIGHT_API_LLM_PROVIDER") != current_provider or
|
||||
saved.get("HINDSIGHT_API_LLM_PROVIDER") != daemon_provider or
|
||||
saved.get("HINDSIGHT_API_LLM_MODEL") != current_model or
|
||||
saved.get("HINDSIGHT_API_LLM_API_KEY") != current_key or
|
||||
saved.get("HINDSIGHT_API_LLM_BASE_URL", "") != current_base_url
|
||||
|
|
@ -340,7 +605,7 @@ class HindsightMemoryProvider(MemoryProvider):
|
|||
# Write updated profile .env
|
||||
profile_env.parent.mkdir(parents=True, exist_ok=True)
|
||||
env_lines = (
|
||||
f"HINDSIGHT_API_LLM_PROVIDER={current_provider}\n"
|
||||
f"HINDSIGHT_API_LLM_PROVIDER={daemon_provider}\n"
|
||||
f"HINDSIGHT_API_LLM_API_KEY={current_key}\n"
|
||||
f"HINDSIGHT_API_LLM_MODEL={current_model}\n"
|
||||
f"HINDSIGHT_API_LOG_LEVEL=info\n"
|
||||
|
|
@ -388,47 +653,118 @@ class HindsightMemoryProvider(MemoryProvider):
|
|||
|
||||
def prefetch(self, query: str, *, session_id: str = "") -> str:
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
logger.debug("Prefetch: waiting for background thread to complete")
|
||||
self._prefetch_thread.join(timeout=3.0)
|
||||
with self._prefetch_lock:
|
||||
result = self._prefetch_result
|
||||
self._prefetch_result = ""
|
||||
if not result:
|
||||
logger.debug("Prefetch: no results available")
|
||||
return ""
|
||||
return f"## Hindsight Memory\n{result}"
|
||||
logger.debug("Prefetch: returning %d chars of context", len(result))
|
||||
header = self._recall_prompt_preamble or (
|
||||
"# Hindsight Memory (persistent cross-session context)\n"
|
||||
"Use this to answer questions about the user and prior sessions. "
|
||||
"Do not call tools to look up information that is already present here."
|
||||
)
|
||||
return f"{header}\n\n{result}"
|
||||
|
||||
def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
|
||||
if self._memory_mode == "tools":
|
||||
logger.debug("Prefetch: skipped (tools-only mode)")
|
||||
return
|
||||
if not self._auto_recall:
|
||||
logger.debug("Prefetch: skipped (auto_recall disabled)")
|
||||
return
|
||||
# Truncate query to max chars
|
||||
if self._recall_max_input_chars and len(query) > self._recall_max_input_chars:
|
||||
query = query[:self._recall_max_input_chars]
|
||||
|
||||
def _run():
|
||||
try:
|
||||
client = self._get_client()
|
||||
if self._prefetch_method == "reflect":
|
||||
logger.debug("Prefetch: calling reflect (bank=%s, query_len=%d)", self._bank_id, len(query))
|
||||
resp = _run_sync(client.areflect(bank_id=self._bank_id, query=query, budget=self._budget))
|
||||
text = resp.text or ""
|
||||
else:
|
||||
resp = _run_sync(client.arecall(bank_id=self._bank_id, query=query, budget=self._budget))
|
||||
text = "\n".join(r.text for r in resp.results if r.text) if resp.results else ""
|
||||
recall_kwargs: dict = {
|
||||
"bank_id": self._bank_id, "query": query,
|
||||
"budget": self._budget, "max_tokens": self._recall_max_tokens,
|
||||
}
|
||||
if self._recall_tags:
|
||||
recall_kwargs["tags"] = self._recall_tags
|
||||
recall_kwargs["tags_match"] = self._recall_tags_match
|
||||
if self._recall_types:
|
||||
recall_kwargs["types"] = self._recall_types
|
||||
logger.debug("Prefetch: calling recall (bank=%s, query_len=%d, budget=%s)",
|
||||
self._bank_id, len(query), self._budget)
|
||||
resp = _run_sync(client.arecall(**recall_kwargs))
|
||||
num_results = len(resp.results) if resp.results else 0
|
||||
logger.debug("Prefetch: recall returned %d results", num_results)
|
||||
text = "\n".join(f"- {r.text}" for r in resp.results if r.text) if resp.results else ""
|
||||
if text:
|
||||
with self._prefetch_lock:
|
||||
self._prefetch_result = text
|
||||
except Exception as e:
|
||||
logger.debug("Hindsight prefetch failed: %s", e)
|
||||
logger.debug("Hindsight prefetch failed: %s", e, exc_info=True)
|
||||
|
||||
self._prefetch_thread = threading.Thread(target=_run, daemon=True, name="hindsight-prefetch")
|
||||
self._prefetch_thread.start()
|
||||
|
||||
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
|
||||
"""Retain conversation turn in background (non-blocking)."""
|
||||
combined = f"User: {user_content}\nAssistant: {assistant_content}"
|
||||
"""Retain conversation turn in background (non-blocking).
|
||||
|
||||
Respects retain_every_n_turns for batching.
|
||||
"""
|
||||
if not self._auto_retain:
|
||||
logger.debug("sync_turn: skipped (auto_retain disabled)")
|
||||
return
|
||||
|
||||
from datetime import datetime, timezone
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": user_content, "timestamp": now},
|
||||
{"role": "assistant", "content": assistant_content, "timestamp": now},
|
||||
]
|
||||
|
||||
turn = json.dumps(messages)
|
||||
self._session_turns.append(turn)
|
||||
self._turn_counter += 1
|
||||
|
||||
# Only retain every N turns
|
||||
if self._turn_counter % self._retain_every_n_turns != 0:
|
||||
logger.debug("sync_turn: buffered turn %d (will retain at turn %d)",
|
||||
self._turn_counter, self._turn_counter + (self._retain_every_n_turns - self._turn_counter % self._retain_every_n_turns))
|
||||
return
|
||||
|
||||
logger.debug("sync_turn: retaining %d turns, total session content %d chars",
|
||||
len(self._session_turns), sum(len(t) for t in self._session_turns))
|
||||
# Send the ENTIRE session as a single JSON array (document_id deduplicates).
|
||||
# Each element in _session_turns is a JSON string of that turn's messages.
|
||||
content = "[" + ",".join(self._session_turns) + "]"
|
||||
|
||||
def _sync():
|
||||
try:
|
||||
client = self._get_client()
|
||||
_run_sync(client.aretain(
|
||||
bank_id=self._bank_id, content=combined, context="conversation"
|
||||
item: dict = {
|
||||
"content": content,
|
||||
"context": self._retain_context,
|
||||
}
|
||||
if self._tags:
|
||||
item["tags"] = self._tags
|
||||
logger.debug("Hindsight retain: bank=%s, doc=%s, async=%s, content_len=%d, num_turns=%d",
|
||||
self._bank_id, self._session_id, self._retain_async, len(content), len(self._session_turns))
|
||||
_run_sync(client.aretain_batch(
|
||||
bank_id=self._bank_id,
|
||||
items=[item],
|
||||
document_id=self._session_id,
|
||||
retain_async=self._retain_async,
|
||||
))
|
||||
logger.debug("Hindsight retain succeeded")
|
||||
except Exception as e:
|
||||
logger.warning("Hindsight sync failed: %s", e)
|
||||
logger.warning("Hindsight sync failed: %s", e, exc_info=True)
|
||||
|
||||
if self._sync_thread and self._sync_thread.is_alive():
|
||||
self._sync_thread.join(timeout=5.0)
|
||||
|
|
@ -453,12 +789,18 @@ class HindsightMemoryProvider(MemoryProvider):
|
|||
return tool_error("Missing required parameter: content")
|
||||
context = args.get("context")
|
||||
try:
|
||||
_run_sync(client.aretain(
|
||||
bank_id=self._bank_id, content=content, context=context
|
||||
))
|
||||
retain_kwargs: dict = {
|
||||
"bank_id": self._bank_id, "content": content, "context": context,
|
||||
}
|
||||
if self._tags:
|
||||
retain_kwargs["tags"] = self._tags
|
||||
logger.debug("Tool hindsight_retain: bank=%s, content_len=%d, context=%s",
|
||||
self._bank_id, len(content), context)
|
||||
_run_sync(client.aretain(**retain_kwargs))
|
||||
logger.debug("Tool hindsight_retain: success")
|
||||
return json.dumps({"result": "Memory stored successfully."})
|
||||
except Exception as e:
|
||||
logger.warning("hindsight_retain failed: %s", e)
|
||||
logger.warning("hindsight_retain failed: %s", e, exc_info=True)
|
||||
return tool_error(f"Failed to store memory: {e}")
|
||||
|
||||
elif tool_name == "hindsight_recall":
|
||||
|
|
@ -466,15 +808,26 @@ class HindsightMemoryProvider(MemoryProvider):
|
|||
if not query:
|
||||
return tool_error("Missing required parameter: query")
|
||||
try:
|
||||
resp = _run_sync(client.arecall(
|
||||
bank_id=self._bank_id, query=query, budget=self._budget
|
||||
))
|
||||
recall_kwargs: dict = {
|
||||
"bank_id": self._bank_id, "query": query, "budget": self._budget,
|
||||
"max_tokens": self._recall_max_tokens,
|
||||
}
|
||||
if self._recall_tags:
|
||||
recall_kwargs["tags"] = self._recall_tags
|
||||
recall_kwargs["tags_match"] = self._recall_tags_match
|
||||
if self._recall_types:
|
||||
recall_kwargs["types"] = self._recall_types
|
||||
logger.debug("Tool hindsight_recall: bank=%s, query_len=%d, budget=%s",
|
||||
self._bank_id, len(query), self._budget)
|
||||
resp = _run_sync(client.arecall(**recall_kwargs))
|
||||
num_results = len(resp.results) if resp.results else 0
|
||||
logger.debug("Tool hindsight_recall: %d results", num_results)
|
||||
if not resp.results:
|
||||
return json.dumps({"result": "No relevant memories found."})
|
||||
lines = [f"{i}. {r.text}" for i, r in enumerate(resp.results, 1)]
|
||||
return json.dumps({"result": "\n".join(lines)})
|
||||
except Exception as e:
|
||||
logger.warning("hindsight_recall failed: %s", e)
|
||||
logger.warning("hindsight_recall failed: %s", e, exc_info=True)
|
||||
return tool_error(f"Failed to search memory: {e}")
|
||||
|
||||
elif tool_name == "hindsight_reflect":
|
||||
|
|
@ -482,24 +835,28 @@ class HindsightMemoryProvider(MemoryProvider):
|
|||
if not query:
|
||||
return tool_error("Missing required parameter: query")
|
||||
try:
|
||||
logger.debug("Tool hindsight_reflect: bank=%s, query_len=%d, budget=%s",
|
||||
self._bank_id, len(query), self._budget)
|
||||
resp = _run_sync(client.areflect(
|
||||
bank_id=self._bank_id, query=query, budget=self._budget
|
||||
))
|
||||
logger.debug("Tool hindsight_reflect: response_len=%d", len(resp.text or ""))
|
||||
return json.dumps({"result": resp.text or "No relevant memories found."})
|
||||
except Exception as e:
|
||||
logger.warning("hindsight_reflect failed: %s", e)
|
||||
logger.warning("hindsight_reflect failed: %s", e, exc_info=True)
|
||||
return tool_error(f"Failed to reflect: {e}")
|
||||
|
||||
return tool_error(f"Unknown tool: {tool_name}")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
logger.debug("Hindsight shutdown: waiting for background threads")
|
||||
global _loop, _loop_thread
|
||||
for t in (self._prefetch_thread, self._sync_thread):
|
||||
if t and t.is_alive():
|
||||
t.join(timeout=5.0)
|
||||
if self._client is not None:
|
||||
try:
|
||||
if self._mode == "local":
|
||||
if self._mode == "local_embedded":
|
||||
# Use the public close() API. The RuntimeError from
|
||||
# aiohttp's "attached to a different loop" is expected
|
||||
# and harmless — the daemon keeps running independently.
|
||||
|
|
|
|||
|
|
@ -2,9 +2,7 @@ name: hindsight
|
|||
version: 1.0.0
|
||||
description: "Hindsight — long-term memory with knowledge graph, entity resolution, and multi-strategy retrieval."
|
||||
pip_dependencies:
|
||||
- hindsight-client
|
||||
- hindsight-all
|
||||
requires_env:
|
||||
- HINDSIGHT_API_KEY
|
||||
- "hindsight-client>=0.4.22"
|
||||
requires_env: []
|
||||
hooks:
|
||||
- on_session_end
|
||||
|
|
|
|||
412
run_agent.py
412
run_agent.py
|
|
@ -66,7 +66,7 @@ from model_tools import (
|
|||
handle_function_call,
|
||||
check_toolset_requirements,
|
||||
)
|
||||
from tools.terminal_tool import cleanup_vm, get_active_env
|
||||
from tools.terminal_tool import cleanup_vm, get_active_env, is_persistent_env
|
||||
from tools.tool_result_storage import maybe_persist_tool_result, enforce_turn_budget
|
||||
from tools.interrupt import set_interrupt as _set_interrupt
|
||||
from tools.browser_tool import cleanup_browser
|
||||
|
|
@ -77,6 +77,7 @@ from hermes_constants import OPENROUTER_BASE_URL
|
|||
# Agent internals extracted to agent/ package for modularity
|
||||
from agent.memory_manager import build_memory_context_block
|
||||
from agent.retry_utils import jittered_backoff
|
||||
from agent.error_classifier import classify_api_error, FailoverReason
|
||||
from agent.prompt_builder import (
|
||||
DEFAULT_AGENT_IDENTITY, PLATFORM_HINTS,
|
||||
MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, SKILLS_GUIDANCE,
|
||||
|
|
@ -442,6 +443,13 @@ class AIAgent:
|
|||
for AI models that support function calling.
|
||||
"""
|
||||
|
||||
# ── Class-level context pressure dedup (survives across instances) ──
|
||||
# The gateway creates a new AIAgent per message, so instance-level flags
|
||||
# reset every time. This dict tracks {session_id: (warn_level, timestamp)}
|
||||
# to suppress duplicate warnings within a cooldown window.
|
||||
_context_pressure_last_warned: dict = {}
|
||||
_CONTEXT_PRESSURE_COOLDOWN = 300 # seconds between re-warning same session
|
||||
|
||||
@property
|
||||
def base_url(self) -> str:
|
||||
return self._base_url
|
||||
|
|
@ -673,7 +681,8 @@ class AIAgent:
|
|||
# Context pressure warnings: notify the USER (not the LLM) as context
|
||||
# fills up. Purely informational — displayed in CLI output and sent via
|
||||
# status_callback for gateway platforms. Does NOT inject into messages.
|
||||
self._context_pressure_warned = False
|
||||
# Tiered: fires at 85% and again at 95% of compaction threshold.
|
||||
self._context_pressure_warned_at = 0.0 # highest tier already shown
|
||||
|
||||
# Activity tracking — updated on each API call, tool execution, and
|
||||
# stream chunk. Used by the gateway timeout handler to report what the
|
||||
|
|
@ -684,6 +693,10 @@ class AIAgent:
|
|||
self._current_tool: str | None = None
|
||||
self._api_call_count: int = 0
|
||||
|
||||
# Rate limit tracking — updated from x-ratelimit-* response headers
|
||||
# after each API call. Accessed by /usage slash command.
|
||||
self._rate_limit_state: Optional["RateLimitState"] = None
|
||||
|
||||
# Centralized logging — agent.log (INFO+) and errors.log (WARNING+)
|
||||
# both live under ~/.hermes/logs/. Idempotent, so gateway mode
|
||||
# (which creates a new AIAgent per message) won't duplicate handlers.
|
||||
|
|
@ -1687,9 +1700,25 @@ class AIAgent:
|
|||
return None
|
||||
|
||||
def _cleanup_task_resources(self, task_id: str) -> None:
|
||||
"""Clean up VM and browser resources for a given task."""
|
||||
"""Clean up VM and browser resources for a given task.
|
||||
|
||||
Skips ``cleanup_vm`` when the active terminal environment is marked
|
||||
persistent (``persistent_filesystem=True``) so that long-lived sandbox
|
||||
containers survive between turns. The idle reaper in
|
||||
``terminal_tool._cleanup_inactive_envs`` still tears them down once
|
||||
``terminal.lifetime_seconds`` is exceeded. Non-persistent backends are
|
||||
torn down per-turn as before to prevent resource leakage (the original
|
||||
intent of this hook for the Morph backend, see commit fbd3a2fd).
|
||||
"""
|
||||
try:
|
||||
cleanup_vm(task_id)
|
||||
if is_persistent_env(task_id):
|
||||
if self.verbose_logging:
|
||||
logging.debug(
|
||||
f"Skipping per-turn cleanup_vm for persistent env {task_id}; "
|
||||
f"idle reaper will handle it."
|
||||
)
|
||||
else:
|
||||
cleanup_vm(task_id)
|
||||
except Exception as e:
|
||||
if self.verbose_logging:
|
||||
logging.warning(f"Failed to cleanup VM for task {task_id}: {e}")
|
||||
|
|
@ -2521,6 +2550,29 @@ class AIAgent:
|
|||
self._last_activity_ts = time.time()
|
||||
self._last_activity_desc = desc
|
||||
|
||||
def _capture_rate_limits(self, http_response: Any) -> None:
|
||||
"""Parse x-ratelimit-* headers from an HTTP response and cache the state.
|
||||
|
||||
Called after each streaming API call. The httpx Response object is
|
||||
available on the OpenAI SDK Stream via ``stream.response``.
|
||||
"""
|
||||
if http_response is None:
|
||||
return
|
||||
headers = getattr(http_response, "headers", None)
|
||||
if not headers:
|
||||
return
|
||||
try:
|
||||
from agent.rate_limit_tracker import parse_rate_limit_headers
|
||||
state = parse_rate_limit_headers(headers, provider=self.provider)
|
||||
if state is not None:
|
||||
self._rate_limit_state = state
|
||||
except Exception:
|
||||
pass # Never let header parsing break the agent loop
|
||||
|
||||
def get_rate_limit_state(self):
    """Expose the cached rate limit state.

    Returns:
        Whatever is currently stored on ``self._rate_limit_state``: the
        most recently captured state object, or ``None`` if nothing has
        been captured.
    """
    cached_state = self._rate_limit_state
    return cached_state
|
||||
|
||||
def get_activity_summary(self) -> dict:
|
||||
"""Return a snapshot of the agent's current activity for diagnostics.
|
||||
|
||||
|
|
@ -4375,6 +4427,11 @@ class AIAgent:
|
|||
self._touch_activity("waiting for provider response (streaming)")
|
||||
stream = request_client_holder["client"].chat.completions.create(**stream_kwargs)
|
||||
|
||||
# Capture rate limit headers from the initial HTTP response.
|
||||
# The OpenAI SDK Stream object exposes the underlying httpx
|
||||
# response via .response before any chunks are consumed.
|
||||
self._capture_rate_limits(getattr(stream, "response", None))
|
||||
|
||||
content_parts: list = []
|
||||
tool_calls_acc: dict = {}
|
||||
tool_gen_notified: set = set()
|
||||
|
|
@ -4728,18 +4785,25 @@ class AIAgent:
|
|||
self._close_request_openai_client(request_client, reason="stream_request_complete")
|
||||
|
||||
_stream_stale_timeout_base = float(os.getenv("HERMES_STREAM_STALE_TIMEOUT", 180.0))
|
||||
# Scale the stale timeout for large contexts: slow models (like Opus)
|
||||
# can legitimately think for minutes before producing the first token
|
||||
# when the context is large. Without this, the stale detector kills
|
||||
# healthy connections during the model's thinking phase, producing
|
||||
# spurious RemoteProtocolError ("peer closed connection").
|
||||
_est_tokens = sum(len(str(v)) for v in api_kwargs.get("messages", [])) // 4
|
||||
if _est_tokens > 100_000:
|
||||
_stream_stale_timeout = max(_stream_stale_timeout_base, 300.0)
|
||||
elif _est_tokens > 50_000:
|
||||
_stream_stale_timeout = max(_stream_stale_timeout_base, 240.0)
|
||||
# Local providers (Ollama, oMLX, llama-cpp) can take 300+ seconds
|
||||
# for prefill on large contexts. Disable the stale detector unless
|
||||
# the user explicitly set HERMES_STREAM_STALE_TIMEOUT.
|
||||
if _stream_stale_timeout_base == 180.0 and self.base_url and is_local_endpoint(self.base_url):
|
||||
_stream_stale_timeout = float("inf")
|
||||
logger.debug("Local provider detected (%s) — stale stream timeout disabled", self.base_url)
|
||||
else:
|
||||
_stream_stale_timeout = _stream_stale_timeout_base
|
||||
# Scale the stale timeout for large contexts: slow models (like Opus)
|
||||
# can legitimately think for minutes before producing the first token
|
||||
# when the context is large. Without this, the stale detector kills
|
||||
# healthy connections during the model's thinking phase, producing
|
||||
# spurious RemoteProtocolError ("peer closed connection").
|
||||
_est_tokens = sum(len(str(v)) for v in api_kwargs.get("messages", [])) // 4
|
||||
if _est_tokens > 100_000:
|
||||
_stream_stale_timeout = max(_stream_stale_timeout_base, 300.0)
|
||||
elif _est_tokens > 50_000:
|
||||
_stream_stale_timeout = max(_stream_stale_timeout_base, 240.0)
|
||||
else:
|
||||
_stream_stale_timeout = _stream_stale_timeout_base
|
||||
|
||||
t = threading.Thread(target=_call, daemon=True)
|
||||
t.start()
|
||||
|
|
@ -5864,7 +5928,7 @@ class AIAgent:
|
|||
tools=[memory_tool_def],
|
||||
temperature=0.3,
|
||||
max_tokens=5120,
|
||||
timeout=30.0,
|
||||
# timeout resolved from auxiliary.flush_memories.timeout config
|
||||
)
|
||||
except RuntimeError:
|
||||
_aux_available = False
|
||||
|
|
@ -5896,7 +5960,10 @@ class AIAgent:
|
|||
"temperature": 0.3,
|
||||
**self._max_tokens_param(5120),
|
||||
}
|
||||
response = self._ensure_primary_openai_client(reason="flush_memories").chat.completions.create(**api_kwargs, timeout=30.0)
|
||||
from agent.auxiliary_client import _get_task_timeout
|
||||
response = self._ensure_primary_openai_client(reason="flush_memories").chat.completions.create(
|
||||
**api_kwargs, timeout=_get_task_timeout("flush_memories")
|
||||
)
|
||||
|
||||
# Extract tool calls from the response, handling all API formats
|
||||
tool_calls = []
|
||||
|
|
@ -6003,6 +6070,15 @@ class AIAgent:
|
|||
except Exception as e:
|
||||
logger.warning("Session DB compression split failed — new session will NOT be indexed: %s", e)
|
||||
|
||||
# Warn on repeated compressions (quality degrades with each pass)
|
||||
_cc = self.context_compressor.compression_count
|
||||
if _cc >= 2:
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Session compressed {_cc} times — "
|
||||
f"accuracy may degrade. Consider /new to start fresh.",
|
||||
force=True,
|
||||
)
|
||||
|
||||
# Update token estimate after compaction so pressure calculations
|
||||
# use the post-compression count, not the stale pre-compression one.
|
||||
_compressed_est = (
|
||||
|
|
@ -6015,12 +6091,16 @@ class AIAgent:
|
|||
# Only reset the pressure warning if compression actually brought
|
||||
# us below the warning level (85% of threshold). When compression
|
||||
# can't reduce enough (e.g. threshold is very low, or system prompt
|
||||
# alone exceeds the warning level), keep the flag set to prevent
|
||||
# alone exceeds the warning level), keep the tier set to prevent
|
||||
# spamming the user with repeated warnings every loop iteration.
|
||||
if self.context_compressor.threshold_tokens > 0:
|
||||
_post_progress = _compressed_est / self.context_compressor.threshold_tokens
|
||||
if _post_progress < 0.85:
|
||||
self._context_pressure_warned = False
|
||||
self._context_pressure_warned_at = 0.0
|
||||
# Clear class-level dedup for this session so a fresh
|
||||
# warning cycle can start if context grows again.
|
||||
_sid = self.session_id or "default"
|
||||
AIAgent._context_pressure_last_warned.pop(_sid, None)
|
||||
|
||||
# Clear the file-read dedup cache. After compression the original
|
||||
# read content is summarised away — if the model re-reads the same
|
||||
|
|
@ -7202,6 +7282,7 @@ class AIAgent:
|
|||
length_continue_retries = 0
|
||||
truncated_response_prefix = ""
|
||||
compression_attempts = 0
|
||||
_turn_exit_reason = "unknown" # Diagnostic: why the loop ended
|
||||
|
||||
# Clear any stale interrupt state at start
|
||||
self.clear_interrupt()
|
||||
|
|
@ -7226,6 +7307,7 @@ class AIAgent:
|
|||
# Check for interrupt request (e.g., user sent new message)
|
||||
if self._interrupt_requested:
|
||||
interrupted = True
|
||||
_turn_exit_reason = "interrupted_by_user"
|
||||
if not self.quiet_mode:
|
||||
self._safe_print("\n⚡ Breaking out of tool loop due to interrupt...")
|
||||
break
|
||||
|
|
@ -7234,6 +7316,7 @@ class AIAgent:
|
|||
self._api_call_count = api_call_count
|
||||
self._touch_activity(f"starting API call #{api_call_count}")
|
||||
if not self.iteration_budget.consume():
|
||||
_turn_exit_reason = "budget_exhausted"
|
||||
if not self.quiet_mode:
|
||||
self._safe_print(f"\n⚠️ Iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} iterations used)")
|
||||
break
|
||||
|
|
@ -7938,6 +8021,25 @@ class AIAgent:
|
|||
|
||||
status_code = getattr(api_error, "status_code", None)
|
||||
error_context = self._extract_api_error_context(api_error)
|
||||
|
||||
# ── Classify the error for structured recovery decisions ──
|
||||
_compressor = getattr(self, "context_compressor", None)
|
||||
_ctx_len = getattr(_compressor, "context_length", 200000) if _compressor else 200000
|
||||
classified = classify_api_error(
|
||||
api_error,
|
||||
provider=getattr(self, "provider", "") or "",
|
||||
model=getattr(self, "model", "") or "",
|
||||
approx_tokens=approx_tokens,
|
||||
context_length=_ctx_len,
|
||||
num_messages=len(api_messages) if api_messages else 0,
|
||||
)
|
||||
logger.debug(
|
||||
"Error classified: reason=%s status=%s retryable=%s compress=%s rotate=%s fallback=%s",
|
||||
classified.reason.value, classified.status_code,
|
||||
classified.retryable, classified.should_compress,
|
||||
classified.should_rotate_credential, classified.should_fallback,
|
||||
)
|
||||
|
||||
recovered_with_pool, has_retried_429 = self._recover_with_credential_pool(
|
||||
status_code=status_code,
|
||||
has_retried_429=has_retried_429,
|
||||
|
|
@ -8000,27 +8102,24 @@ class AIAgent:
|
|||
# from all messages so the next retry sends no thinking
|
||||
# blocks at all. One-shot — don't retry infinitely.
|
||||
if (
|
||||
self.api_mode == "anthropic_messages"
|
||||
and status_code == 400
|
||||
classified.reason == FailoverReason.thinking_signature
|
||||
and not thinking_sig_retry_attempted
|
||||
):
|
||||
_err_msg_lower = str(api_error).lower()
|
||||
if "signature" in _err_msg_lower and "thinking" in _err_msg_lower:
|
||||
thinking_sig_retry_attempted = True
|
||||
for _m in messages:
|
||||
if isinstance(_m, dict):
|
||||
_m.pop("reasoning_details", None)
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Thinking block signature invalid — "
|
||||
f"stripped all thinking blocks, retrying...",
|
||||
force=True,
|
||||
)
|
||||
logging.warning(
|
||||
"%sThinking block signature recovery: stripped "
|
||||
"reasoning_details from %d messages",
|
||||
self.log_prefix, len(messages),
|
||||
)
|
||||
continue
|
||||
thinking_sig_retry_attempted = True
|
||||
for _m in messages:
|
||||
if isinstance(_m, dict):
|
||||
_m.pop("reasoning_details", None)
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Thinking block signature invalid — "
|
||||
f"stripped all thinking blocks, retrying...",
|
||||
force=True,
|
||||
)
|
||||
logging.warning(
|
||||
"%sThinking block signature recovery: stripped "
|
||||
"reasoning_details from %d messages",
|
||||
self.log_prefix, len(messages),
|
||||
)
|
||||
continue
|
||||
|
||||
retry_count += 1
|
||||
elapsed_time = time.time() - api_start_time
|
||||
|
|
@ -8077,14 +8176,7 @@ class AIAgent:
|
|||
# is NOT a transient rate limit — retrying or switching
|
||||
# credentials won't help. Reduce context to 200k (the
|
||||
# standard tier) and compress.
|
||||
# Only applies to Sonnet — Opus 1M is general access.
|
||||
_is_long_context_tier_error = (
|
||||
status_code == 429
|
||||
and "extra usage" in error_msg
|
||||
and "long context" in error_msg
|
||||
and "sonnet" in self.model.lower()
|
||||
)
|
||||
if _is_long_context_tier_error:
|
||||
if classified.reason == FailoverReason.long_context_tier:
|
||||
_reduced_ctx = 200000
|
||||
compressor = self.context_compressor
|
||||
old_ctx = compressor.context_length
|
||||
|
|
@ -8129,13 +8221,9 @@ class AIAgent:
|
|||
# When a fallback model is configured, switch immediately instead
|
||||
# of burning through retries with exponential backoff -- the
|
||||
# primary provider won't recover within the retry window.
|
||||
is_rate_limited = (
|
||||
status_code == 429
|
||||
or "rate limit" in error_msg
|
||||
or "too many requests" in error_msg
|
||||
or "rate_limit" in error_msg
|
||||
or "usage limit" in error_msg
|
||||
or "quota" in error_msg
|
||||
is_rate_limited = classified.reason in (
|
||||
FailoverReason.rate_limit,
|
||||
FailoverReason.billing,
|
||||
)
|
||||
if is_rate_limited and self._fallback_index < len(self._fallback_chain):
|
||||
# Don't eagerly fallback if credential pool rotation may
|
||||
|
|
@ -8151,10 +8239,7 @@ class AIAgent:
|
|||
continue
|
||||
|
||||
is_payload_too_large = (
|
||||
status_code == 413
|
||||
or 'request entity too large' in error_msg
|
||||
or 'payload too large' in error_msg
|
||||
or 'error code: 413' in error_msg
|
||||
classified.reason == FailoverReason.payload_too_large
|
||||
)
|
||||
|
||||
if is_payload_too_large:
|
||||
|
|
@ -8198,64 +8283,12 @@ class AIAgent:
|
|||
}
|
||||
|
||||
# Check for context-length errors BEFORE generic 4xx handler.
|
||||
# Local backends (LM Studio, Ollama, llama.cpp) often return
|
||||
# HTTP 400 with messages like "Context size has been exceeded"
|
||||
# which must trigger compression, not an immediate abort.
|
||||
is_context_length_error = any(phrase in error_msg for phrase in [
|
||||
'context length', 'context size', 'maximum context',
|
||||
'token limit', 'too many tokens', 'reduce the length',
|
||||
'exceeds the limit', 'context window',
|
||||
'request entity too large', # OpenRouter/Nous 413 safety net
|
||||
'prompt is too long', # Anthropic: "prompt is too long: N tokens > M maximum"
|
||||
'prompt exceeds max length', # Z.AI / GLM: generic 400 overflow wording
|
||||
])
|
||||
|
||||
# Fallback heuristic: Anthropic sometimes returns a generic
|
||||
# 400 invalid_request_error with just "Error" as the message
|
||||
# when the context is too large. If the error message is very
|
||||
# short/generic AND the session is large, treat it as a
|
||||
# probable context-length error and attempt compression rather
|
||||
# than aborting. This prevents an infinite failure loop where
|
||||
# each failed message gets persisted, making the session even
|
||||
# larger. (#1630)
|
||||
if not is_context_length_error and status_code == 400:
|
||||
ctx_len = getattr(getattr(self, 'context_compressor', None), 'context_length', 200000)
|
||||
is_large_session = approx_tokens > ctx_len * 0.4 or len(api_messages) > 80
|
||||
is_generic_error = len(error_msg.strip()) < 30 # e.g. just "error"
|
||||
if is_large_session and is_generic_error:
|
||||
is_context_length_error = True
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Generic 400 with large session "
|
||||
f"(~{approx_tokens:,} tokens, {len(api_messages)} msgs) — "
|
||||
f"treating as probable context overflow.",
|
||||
force=True,
|
||||
)
|
||||
|
||||
# Server disconnects on large sessions are often caused by
|
||||
# the request exceeding the provider's context/payload limit
|
||||
# without a proper HTTP error response. Treat these as
|
||||
# context-length errors to trigger compression rather than
|
||||
# burning through retries that will all fail the same way.
|
||||
# This breaks the death spiral: disconnect → no token data
|
||||
# → no compression → bigger session → more disconnects.
|
||||
# (#2153)
|
||||
if not is_context_length_error and not status_code:
|
||||
_is_server_disconnect = (
|
||||
'server disconnected' in error_msg
|
||||
or 'peer closed connection' in error_msg
|
||||
or error_type in ('ReadError', 'RemoteProtocolError', 'ServerDisconnectedError')
|
||||
)
|
||||
if _is_server_disconnect:
|
||||
ctx_len = getattr(getattr(self, 'context_compressor', None), 'context_length', 200000)
|
||||
_is_large = approx_tokens > ctx_len * 0.6 or len(api_messages) > 200
|
||||
if _is_large:
|
||||
is_context_length_error = True
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Server disconnected with large session "
|
||||
f"(~{approx_tokens:,} tokens, {len(api_messages)} msgs) — "
|
||||
f"treating as context-length error, attempting compression.",
|
||||
force=True,
|
||||
)
|
||||
# The classifier detects context overflow from: explicit error
|
||||
# messages, generic 400 + large session heuristic (#1630), and
|
||||
# server disconnect + large session pattern (#2153).
|
||||
is_context_length_error = (
|
||||
classified.reason == FailoverReason.context_overflow
|
||||
)
|
||||
|
||||
if is_context_length_error:
|
||||
compressor = self.context_compressor
|
||||
|
|
@ -8327,35 +8360,30 @@ class AIAgent:
|
|||
"partial": True
|
||||
}
|
||||
|
||||
# Check for non-retryable client errors (4xx HTTP status codes).
|
||||
# These indicate a problem with the request itself (bad model ID,
|
||||
# invalid API key, forbidden, etc.) and will never succeed on retry.
|
||||
# Note: 413 and context-length errors are excluded — handled above.
|
||||
# 429 (rate limit) is transient and MUST be retried with backoff.
|
||||
# 529 (Anthropic overloaded) is also transient.
|
||||
# Also catch local validation errors (ValueError, TypeError) — these
|
||||
# are programming bugs, not transient failures.
|
||||
# Exclude UnicodeEncodeError — it's a ValueError subclass but is
|
||||
# handled separately by the surrogate sanitization path above.
|
||||
_RETRYABLE_STATUS_CODES = {413, 429, 529}
|
||||
# Check for non-retryable client errors. The classifier
|
||||
# already accounts for 413, 429, 529 (transient), context
|
||||
# overflow, and generic-400 heuristics. Local validation
|
||||
# errors (ValueError, TypeError) are programming bugs.
|
||||
is_local_validation_error = (
|
||||
isinstance(api_error, (ValueError, TypeError))
|
||||
and not isinstance(api_error, UnicodeEncodeError)
|
||||
)
|
||||
# Detect generic 400s from Anthropic OAuth (transient server-side failures).
|
||||
# Real invalid_request_error responses include a descriptive message;
|
||||
# transient ones contain only "Error" or are empty. (ref: issue #1608)
|
||||
_err_body = getattr(api_error, "body", None) or {}
|
||||
_err_message = (_err_body.get("error", {}).get("message", "") if isinstance(_err_body, dict) else "")
|
||||
_is_generic_400 = (status_code == 400 and _err_message.strip().lower() in ("error", ""))
|
||||
is_client_status_error = isinstance(status_code, int) and 400 <= status_code < 500 and status_code not in _RETRYABLE_STATUS_CODES and not _is_generic_400
|
||||
is_client_error = (is_local_validation_error or is_client_status_error or any(phrase in error_msg for phrase in [
|
||||
'error code: 401', 'error code: 403',
|
||||
'error code: 404', 'error code: 422',
|
||||
'is not a valid model', 'invalid model', 'model not found',
|
||||
'invalid api key', 'invalid_api_key', 'authentication',
|
||||
'unauthorized', 'forbidden', 'not found',
|
||||
])) and not is_context_length_error
|
||||
is_client_error = (
|
||||
is_local_validation_error
|
||||
or (
|
||||
not classified.retryable
|
||||
and not classified.should_compress
|
||||
and classified.reason not in (
|
||||
FailoverReason.rate_limit,
|
||||
FailoverReason.billing,
|
||||
FailoverReason.overloaded,
|
||||
FailoverReason.context_overflow,
|
||||
FailoverReason.payload_too_large,
|
||||
FailoverReason.long_context_tier,
|
||||
FailoverReason.thinking_signature,
|
||||
)
|
||||
)
|
||||
) and not is_context_length_error
|
||||
|
||||
if is_client_error:
|
||||
# Try fallback before aborting — a different provider
|
||||
|
|
@ -8375,7 +8403,7 @@ class AIAgent:
|
|||
self._vprint(f"{self.log_prefix} 🔌 Provider: {_provider} Model: {_model}", force=True)
|
||||
self._vprint(f"{self.log_prefix} 🌐 Endpoint: {_base}", force=True)
|
||||
# Actionable guidance for common auth errors
|
||||
if status_code in (401, 403) or "unauthorized" in error_msg or "forbidden" in error_msg or "permission" in error_msg:
|
||||
if classified.is_auth or classified.reason == FailoverReason.billing:
|
||||
if _provider == "openai-codex" and status_code == 401:
|
||||
self._vprint(f"{self.log_prefix} 💡 Codex OAuth token was rejected (HTTP 401). Your token may have been", force=True)
|
||||
self._vprint(f"{self.log_prefix} refreshed by another client (Codex CLI, VS Code). To fix:", force=True)
|
||||
|
|
@ -8535,6 +8563,7 @@ class AIAgent:
|
|||
|
||||
# If the API call was interrupted, skip response processing
|
||||
if interrupted:
|
||||
_turn_exit_reason = "interrupted_during_api_call"
|
||||
break
|
||||
|
||||
if restart_with_compressed_messages:
|
||||
|
|
@ -8554,6 +8583,7 @@ class AIAgent:
|
|||
# (e.g. repeated context-length errors that exhausted retry_count),
|
||||
# the `response` variable is still None. Break out cleanly.
|
||||
if response is None:
|
||||
_turn_exit_reason = "all_retries_exhausted_no_response"
|
||||
print(f"{self.log_prefix}❌ All API retries exhausted with no successful response.")
|
||||
self._persist_session(messages, conversation_history)
|
||||
break
|
||||
|
|
@ -8960,13 +8990,34 @@ class AIAgent:
|
|||
# compaction fires, not the raw context window.
|
||||
# Does not inject into messages — just prints to CLI output
|
||||
# and fires status_callback for gateway platforms.
|
||||
# Tiered: 85% (orange) and 95% (red/critical).
|
||||
if _compressor.threshold_tokens > 0:
|
||||
_compaction_progress = _real_tokens / _compressor.threshold_tokens
|
||||
if _compaction_progress >= 0.85 and not self._context_pressure_warned:
|
||||
self._context_pressure_warned = True
|
||||
self._emit_context_pressure(_compaction_progress, _compressor)
|
||||
# Determine the warning tier for this progress level
|
||||
_warn_tier = 0.0
|
||||
if _compaction_progress >= 0.95:
|
||||
_warn_tier = 0.95
|
||||
elif _compaction_progress >= 0.85:
|
||||
_warn_tier = 0.85
|
||||
if _warn_tier > self._context_pressure_warned_at:
|
||||
# Class-level dedup: check if this session was already
|
||||
# warned at this tier within the cooldown window.
|
||||
_sid = self.session_id or "default"
|
||||
_last = AIAgent._context_pressure_last_warned.get(_sid)
|
||||
_now = time.time()
|
||||
if _last is None or _last[0] < _warn_tier or (_now - _last[1]) >= self._CONTEXT_PRESSURE_COOLDOWN:
|
||||
self._context_pressure_warned_at = _warn_tier
|
||||
AIAgent._context_pressure_last_warned[_sid] = (_warn_tier, _now)
|
||||
self._emit_context_pressure(_compaction_progress, _compressor)
|
||||
# Evict stale entries (older than 2x cooldown)
|
||||
_cutoff = _now - self._CONTEXT_PRESSURE_COOLDOWN * 2
|
||||
AIAgent._context_pressure_last_warned = {
|
||||
k: v for k, v in AIAgent._context_pressure_last_warned.items()
|
||||
if v[1] > _cutoff
|
||||
}
|
||||
|
||||
if self.compression_enabled and _compressor.should_compress(_real_tokens):
|
||||
self._safe_print(" ⟳ compacting context…")
|
||||
messages, active_system_prompt = self._compress_context(
|
||||
messages, system_message,
|
||||
approx_tokens=self.context_compressor.last_prompt_tokens,
|
||||
|
|
@ -8996,6 +9047,7 @@ class AIAgent:
|
|||
# instead of wasting API calls on retries that won't help.
|
||||
fallback = getattr(self, '_last_content_with_tools', None)
|
||||
if fallback:
|
||||
_turn_exit_reason = "fallback_prior_turn_content"
|
||||
logger.debug("Empty follow-up after tool calls — using prior turn content as final response")
|
||||
self._last_content_with_tools = None
|
||||
self._empty_content_retries = 0
|
||||
|
|
@ -9041,8 +9093,28 @@ class AIAgent:
|
|||
self._save_session_log(messages)
|
||||
continue
|
||||
|
||||
# Exhausted prefill attempts or no structured
|
||||
# reasoning — fall through to "(empty)" terminal.
|
||||
# ── Empty response retry (no reasoning) ──────
|
||||
# Model returned nothing — no content, no
|
||||
# structured reasoning, no tool calls. Common
|
||||
# with open models (transient provider issues,
|
||||
# rate limits, sampling flukes). Silently retry
|
||||
# up to 3 times before giving up. Skip when
|
||||
# content has inline <think> tags (model chose
|
||||
# to reason, just no visible text).
|
||||
_truly_empty = not final_response.strip()
|
||||
if _truly_empty and not _has_structured and self._empty_content_retries < 3:
|
||||
self._empty_content_retries += 1
|
||||
self._vprint(
|
||||
f"{self.log_prefix}↻ Empty response (no content or reasoning) "
|
||||
f"— retrying ({self._empty_content_retries}/3)",
|
||||
force=True,
|
||||
)
|
||||
continue
|
||||
|
||||
# Exhausted prefill attempts, empty retries, or
|
||||
# structured reasoning with no content —
|
||||
# fall through to "(empty)" terminal.
|
||||
_turn_exit_reason = "empty_response_exhausted"
|
||||
reasoning_text = self._extract_reasoning(assistant_message)
|
||||
assistant_msg = self._build_assistant_message(assistant_message, finish_reason)
|
||||
assistant_msg["content"] = "(empty)"
|
||||
|
|
@ -9052,7 +9124,7 @@ class AIAgent:
|
|||
reasoning_preview = reasoning_text[:500] + "..." if len(reasoning_text) > 500 else reasoning_text
|
||||
self._vprint(f"{self.log_prefix}ℹ️ Reasoning-only response (no visible content). Reasoning: {reasoning_preview}")
|
||||
else:
|
||||
self._vprint(f"{self.log_prefix}ℹ️ Empty response (no content or reasoning).")
|
||||
self._vprint(f"{self.log_prefix}ℹ️ Empty response (no content or reasoning) after 3 retries.")
|
||||
|
||||
final_response = "(empty)"
|
||||
break
|
||||
|
|
@ -9114,6 +9186,7 @@ class AIAgent:
|
|||
|
||||
messages.append(final_msg)
|
||||
|
||||
_turn_exit_reason = f"text_response(finish_reason={finish_reason})"
|
||||
if not self.quiet_mode:
|
||||
self._safe_print(f"🎉 Conversation completed after {api_call_count} OpenAI-compatible API call(s)")
|
||||
break
|
||||
|
|
@ -9163,6 +9236,7 @@ class AIAgent:
|
|||
|
||||
# If we're near the limit, break to avoid infinite loops
|
||||
if api_call_count >= self.max_iterations - 1:
|
||||
_turn_exit_reason = f"error_near_max_iterations({error_msg[:80]})"
|
||||
final_response = f"I apologize, but I encountered repeated errors: {error_msg}"
|
||||
# Append as assistant so the history stays valid for
|
||||
# session resume (avoids consecutive user messages).
|
||||
|
|
@ -9173,6 +9247,7 @@ class AIAgent:
|
|||
api_call_count >= self.max_iterations
|
||||
or self.iteration_budget.remaining <= 0
|
||||
):
|
||||
_turn_exit_reason = f"max_iterations_reached({api_call_count}/{self.max_iterations})"
|
||||
if self.iteration_budget.remaining <= 0 and not self.quiet_mode:
|
||||
print(f"\n⚠️ Iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} iterations used)")
|
||||
final_response = self._handle_max_iterations(messages, api_call_count)
|
||||
|
|
@ -9189,6 +9264,49 @@ class AIAgent:
|
|||
# Persist session to both JSON log and SQLite
|
||||
self._persist_session(messages, conversation_history)
|
||||
|
||||
# ── Turn-exit diagnostic log ─────────────────────────────────────
|
||||
# Always logged at INFO so agent.log captures WHY every turn ended.
|
||||
# When the last message is a tool result (agent was mid-work), log
|
||||
# at WARNING — this is the "just stops" scenario users report.
|
||||
_last_msg_role = messages[-1].get("role") if messages else None
|
||||
_last_tool_name = None
|
||||
if _last_msg_role == "tool":
|
||||
# Walk back to find the assistant message with the tool call
|
||||
for _m in reversed(messages):
|
||||
if _m.get("role") == "assistant" and _m.get("tool_calls"):
|
||||
_tcs = _m["tool_calls"]
|
||||
if _tcs and isinstance(_tcs[0], dict):
|
||||
_last_tool_name = _tcs[-1].get("function", {}).get("name")
|
||||
break
|
||||
|
||||
_turn_tool_count = sum(
|
||||
1 for m in messages
|
||||
if isinstance(m, dict) and m.get("role") == "assistant" and m.get("tool_calls")
|
||||
)
|
||||
_resp_len = len(final_response) if final_response else 0
|
||||
_budget_used = self.iteration_budget.used if self.iteration_budget else 0
|
||||
_budget_max = self.iteration_budget.max_total if self.iteration_budget else 0
|
||||
|
||||
_diag_msg = (
|
||||
"Turn ended: reason=%s model=%s api_calls=%d/%d budget=%d/%d "
|
||||
"tool_turns=%d last_msg_role=%s response_len=%d session=%s"
|
||||
)
|
||||
_diag_args = (
|
||||
_turn_exit_reason, self.model, api_call_count, self.max_iterations,
|
||||
_budget_used, _budget_max,
|
||||
_turn_tool_count, _last_msg_role, _resp_len,
|
||||
self.session_id or "none",
|
||||
)
|
||||
|
||||
if _last_msg_role == "tool" and not interrupted:
|
||||
# Agent was mid-work — this is the "just stops" case.
|
||||
logger.warning(
|
||||
"Turn ended with pending tool result (agent may appear stuck). "
|
||||
+ _diag_msg + " last_tool=%s",
|
||||
*_diag_args, _last_tool_name,
|
||||
)
|
||||
else:
|
||||
logger.info(_diag_msg, *_diag_args)
|
||||
|
||||
# Plugin hook: post_llm_call
|
||||
# Fired once per turn after the tool-calling loop completes.
|
||||
|
|
|
|||
|
|
@ -77,6 +77,20 @@ class TestReadCodexAccessToken:
|
|||
result = _read_codex_access_token()
|
||||
assert result == "tok-123"
|
||||
|
||||
def test_pool_without_selected_entry_falls_back_to_auth_store(self, tmp_path, monkeypatch):
    """Fall back to the auth store when the pool yields no selected entry.

    ``_select_pool_entry`` is patched to return ``(True, None)`` and
    ``_read_codex_tokens`` is patched to return a token payload; the access
    token from that payload is what ``_read_codex_access_token`` must return.
    """
    home_dir = tmp_path / "hermes"
    home_dir.mkdir(parents=True, exist_ok=True)
    monkeypatch.setenv("HERMES_HOME", str(home_dir))

    # JWT-shaped token; the middle segment is base64 for {"exp":9999999999}.
    stored_token = "eyJhbGciOiJSUzI1NiJ9.eyJleHAiOjk5OTk5OTk5OTl9.sig"
    token_payload = {
        "tokens": {"access_token": stored_token, "refresh_token": "refresh"}
    }
    with (
        patch("agent.auxiliary_client._select_pool_entry", return_value=(True, None)),
        patch("hermes_cli.auth._read_codex_tokens", return_value=token_payload),
    ):
        token = _read_codex_access_token()

    assert token == stored_token
|
||||
|
||||
def test_missing_returns_none(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -238,6 +252,24 @@ class TestAnthropicOAuthFlag:
|
|||
assert mock_build.call_args.args[0] == "sk-ant-oat01-pooled"
|
||||
|
||||
|
||||
class TestTryCodex:
    """``_try_codex`` behaviour when the credential pool selects no entry."""

    def test_pool_without_selected_entry_falls_back_to_auth_store(self):
        """The client is built from the token ``_read_codex_access_token`` returns."""
        pool_patch = patch(
            "agent.auxiliary_client._select_pool_entry", return_value=(True, None)
        )
        token_patch = patch(
            "agent.auxiliary_client._read_codex_access_token",
            return_value="codex-auth-token",
        )
        openai_patch = patch("agent.auxiliary_client.OpenAI")

        # Patches are entered in the same order as they are listed here.
        with pool_patch, token_patch, openai_patch as openai_cls:
            openai_cls.return_value = MagicMock()
            from agent.auxiliary_client import _try_codex

            client, model = _try_codex()

        ctor_kwargs = openai_cls.call_args.kwargs
        assert client is not None
        assert model == "gpt-5.2-codex"
        assert ctor_kwargs["api_key"] == "codex-auth-token"
        assert ctor_kwargs["base_url"] == "https://chatgpt.com/backend-api/codex"
|
||||
|
||||
|
||||
class TestExpiredCodexFallback:
|
||||
"""Test that expired Codex tokens don't block the auto chain."""
|
||||
|
||||
|
|
|
|||
|
|
@ -324,7 +324,10 @@ class TestCompressWithClient:
|
|||
with patch("agent.context_compressor.get_model_context_length", return_value=100000):
|
||||
c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2)
|
||||
|
||||
# Last head message (index 1) is "assistant" → summary should be "user"
|
||||
# Last head message (index 1) is "assistant" → summary should be "user".
|
||||
# With min_tail=3, tail = last 3 messages (indices 5-7).
|
||||
# head_last=assistant, tail_first=assistant → summary_role="user", no collision.
|
||||
# Need 8 messages: min_for_compress = 2+3+1 = 6, must have > 6.
|
||||
msgs = [
|
||||
{"role": "user", "content": "msg 0"},
|
||||
{"role": "assistant", "content": "msg 1"},
|
||||
|
|
@ -332,6 +335,8 @@ class TestCompressWithClient:
|
|||
{"role": "assistant", "content": "msg 3"},
|
||||
{"role": "user", "content": "msg 4"},
|
||||
{"role": "assistant", "content": "msg 5"},
|
||||
{"role": "user", "content": "msg 6"},
|
||||
{"role": "assistant", "content": "msg 7"},
|
||||
]
|
||||
with patch("agent.context_compressor.call_llm", return_value=mock_response):
|
||||
result = c.compress(msgs)
|
||||
|
|
@ -460,8 +465,10 @@ class TestCompressWithClient:
|
|||
c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2)
|
||||
|
||||
# Head: [system, user] → last head = user
|
||||
# Tail: [assistant, user] → first tail = assistant
|
||||
# Tail: [assistant, user, assistant] → first tail = assistant
|
||||
# summary_role="assistant" collides with tail, "user" collides with head → merge
|
||||
# With min_tail=3, tail = last 3 messages (indices 5-7).
|
||||
# Need 8 messages: min_for_compress = 2+3+1 = 6, must have > 6.
|
||||
msgs = [
|
||||
{"role": "system", "content": "system prompt"},
|
||||
{"role": "user", "content": "msg 1"},
|
||||
|
|
@ -470,6 +477,7 @@ class TestCompressWithClient:
|
|||
{"role": "assistant", "content": "msg 4"}, # compressed
|
||||
{"role": "assistant", "content": "msg 5"}, # tail start
|
||||
{"role": "user", "content": "msg 6"},
|
||||
{"role": "assistant", "content": "msg 7"},
|
||||
]
|
||||
with patch("agent.context_compressor.call_llm", return_value=mock_response):
|
||||
result = c.compress(msgs)
|
||||
|
|
@ -481,7 +489,7 @@ class TestCompressWithClient:
|
|||
if r1 in ("user", "assistant") and r2 in ("user", "assistant"):
|
||||
assert r1 != r2, f"consecutive {r1} at indices {i-1},{i}"
|
||||
|
||||
# The summary should be merged into the first tail message (assistant)
|
||||
# The summary should be merged into the first tail message (assistant at index 5)
|
||||
first_tail = [m for m in result if "msg 5" in (m.get("content") or "")]
|
||||
assert len(first_tail) == 1
|
||||
assert "summary text" in first_tail[0]["content"]
|
||||
|
|
@ -496,14 +504,18 @@ class TestCompressWithClient:
|
|||
with patch("agent.context_compressor.get_model_context_length", return_value=100000):
|
||||
c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2)
|
||||
|
||||
# Head=assistant, Tail=assistant → summary_role="user", no collision
|
||||
# Head=assistant, Tail=assistant → summary_role="user", no collision.
|
||||
# With min_tail=3, tail = last 3 messages (indices 5-7).
|
||||
# Need 8 messages: min_for_compress = 2+3+1 = 6, must have > 6.
|
||||
msgs = [
|
||||
{"role": "user", "content": "msg 0"},
|
||||
{"role": "assistant", "content": "msg 1"},
|
||||
{"role": "user", "content": "msg 2"},
|
||||
{"role": "assistant", "content": "msg 3"},
|
||||
{"role": "assistant", "content": "msg 4"},
|
||||
{"role": "user", "content": "msg 5"},
|
||||
{"role": "user", "content": "msg 4"},
|
||||
{"role": "assistant", "content": "msg 5"},
|
||||
{"role": "user", "content": "msg 6"},
|
||||
{"role": "assistant", "content": "msg 7"},
|
||||
]
|
||||
with patch("agent.context_compressor.call_llm", return_value=mock_response):
|
||||
result = c.compress(msgs)
|
||||
|
|
@ -600,3 +612,158 @@ class TestSummaryTargetRatio:
|
|||
with patch("agent.context_compressor.get_model_context_length", return_value=100_000):
|
||||
c = ContextCompressor(model="test", quiet_mode=True)
|
||||
assert c.protect_last_n == 20
|
||||
|
||||
|
||||
class TestTokenBudgetTailProtection:
    """Tests for token-budget-based tail protection (PR #6240).

    Tail protection is driven by a token budget instead of a fixed message
    count, so that a run of large tool outputs cannot block compaction.
    """

    @pytest.fixture()
    def budget_compressor(self):
        """Build a compressor whose context length is patched to 200K tokens."""
        with patch("agent.context_compressor.get_model_context_length", return_value=200_000):
            return ContextCompressor(
                model="test/model",
                threshold_percent=0.50,  # 50% of 200K = 100K threshold
                protect_first_n=2,
                protect_last_n=20,
                quiet_mode=True,
            )

    def test_large_tool_outputs_no_longer_block_compaction(self, budget_compressor):
        """Twenty bulky tool-exchange messages must not all land in the tail.

        This is the motivating scenario: under message-count protection every
        one of them would be protected, leaving nothing to summarize.
        """
        compressor = budget_compressor
        history = [
            {"role": "user", "content": "Start task"},
            {"role": "assistant", "content": "On it"},
        ]
        # Ten call/result pairs = 20 messages; each result is 5000 characters.
        for idx in range(10):
            history.extend((
                {
                    "role": "assistant", "content": None,
                    "tool_calls": [{"function": {"name": f"tool_{idx}", "arguments": "{}"}}],
                },
                {
                    "role": "tool", "content": "x" * 5000,
                    "tool_call_id": f"call_{idx}",
                },
            ))
        # Three small, recent messages at the very end.
        history += [
            {"role": "user", "content": "What's the status?"},
            {"role": "assistant", "content": "Here's what I found..."},
            {"role": "user", "content": "Continue"},
        ]

        tail_start = compressor._find_tail_cut_by_tokens(history, compressor.protect_first_n)
        protected = len(history) - tail_start

        # The tail must be smaller than the 20 tool-exchange messages ...
        assert protected < 20, f"Tail {protected} messages — large tool outputs are blocking compaction"
        # ... but never below the hard minimum of three.
        assert protected >= 3

    def test_min_tail_always_3_messages(self, budget_compressor):
        """Even with a tiny token budget, at least 3 messages are protected."""
        compressor = budget_compressor
        compressor.tail_token_budget = 10  # deliberately tiny
        history = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
            {"role": "user", "content": "do something"},
            {"role": "assistant", "content": "working on it"},
            {"role": "user", "content": "more work"},
            {"role": "assistant", "content": "done"},
            {"role": "user", "content": "thanks"},
        ]

        tail_start = compressor._find_tail_cut_by_tokens(history, 2)
        protected = len(history) - tail_start

        assert protected >= 3, f"Tail is only {protected} messages, min should be 3"

    def test_soft_ceiling_allows_oversized_message(self, budget_compressor):
        """An oversized message near the tail still leaves the minimum tail intact.

        The scenario targets the 1.5x soft ceiling (750 tokens for a budget of
        500); the assertion itself only checks the three-message minimum.
        """
        compressor = budget_compressor
        compressor.tail_token_budget = 500
        history = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
            {"role": "user", "content": "read the file"},
            # 2400 chars: ~600 tokens assuming ~4 chars/token, i.e. over the
            # 500 budget but under the 750 soft ceiling.
            {"role": "assistant", "content": "a" * 2400},
            {"role": "user", "content": "short"},
            {"role": "assistant", "content": "short reply"},
            {"role": "user", "content": "continue"},
        ]

        tail_start = compressor._find_tail_cut_by_tokens(history, 2)
        protected = len(history) - tail_start

        assert protected >= 3

    def test_small_conversation_still_compresses(self, budget_compressor):
        """A short but compressible nine-message conversation gets compressed."""
        compressor = budget_compressor
        # Alternating user/assistant turns: head(2) + 4 middle + 3 tail.
        history = [
            {"role": "user" if i % 2 == 0 else "assistant", "content": f"Message {i}"}
            for i in range(9)
        ]

        # Nine messages is above the protect_first_n + 3 + 1 = 6 minimum this
        # test was written against. Summary generation is stubbed so no real
        # API call is made.
        with patch.object(compressor, "_generate_summary", return_value="Summary of conversation"):
            compressed = compressor.compress(history, current_tokens=90_000)

        assert len(compressed) < len(history)

    def test_prune_with_token_budget(self, budget_compressor):
        """_prune_old_tool_results with protect_tail_tokens respects the budget."""
        compressor = budget_compressor
        history = [
            {"role": "user", "content": "start"},
            {"role": "assistant", "content": None,
             "tool_calls": [{"function": {"name": "read_file", "arguments": '{"path": "big.txt"}'}}]},
            # 10000 chars each: ~2500 tokens assuming ~4 chars/token.
            {"role": "tool", "content": "x" * 10000, "tool_call_id": "c1"},
            {"role": "assistant", "content": None,
             "tool_calls": [{"function": {"name": "read_file", "arguments": '{"path": "small.txt"}'}}]},
            {"role": "tool", "content": "y" * 10000, "tool_call_id": "c2"},
            {"role": "user", "content": "short recent message"},
            {"role": "assistant", "content": "short reply"},
        ]

        # A 1000-token budget should only cover the last couple of messages.
        _, pruned = compressor._prune_old_tool_results(
            history, protect_tail_count=2, protect_tail_tokens=1000,
        )

        assert pruned >= 1

    def test_prune_without_token_budget_uses_message_count(self, budget_compressor):
        """Without protect_tail_tokens, falls back to message-count behavior."""
        compressor = budget_compressor
        history = [
            {"role": "user", "content": "start"},
            {"role": "assistant", "content": None,
             "tool_calls": [{"function": {"name": "tool", "arguments": "{}"}}]},
            {"role": "tool", "content": "x" * 5000, "tool_call_id": "c1"},
            {"role": "user", "content": "recent"},
            {"role": "assistant", "content": "reply"},
        ]

        _, pruned = compressor._prune_old_tool_results(
            history, protect_tail_count=3,
        )

        # The last three messages are indices 2-4, so the tool result at
        # index 2 sits on the boundary of the protected tail. Only the type
        # of the returned prune count is checked here.
        assert isinstance(pruned, int)
|
||||
|
|
|
|||
|
|
@ -214,6 +214,42 @@ def test_exhausted_entry_resets_after_ttl(tmp_path, monkeypatch):
|
|||
assert entry.last_status == "ok"
|
||||
|
||||
|
||||
def test_exhausted_402_entry_resets_after_one_hour(tmp_path, monkeypatch):
    """402-exhausted credentials recover after 1 hour, not 24."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))

    # A single credential marked exhausted by a 402 slightly over an hour ago.
    exhausted_credential = {
        "id": "cred-1",
        "label": "primary",
        "auth_type": "api_key",
        "priority": 0,
        "source": "manual",
        "access_token": "***",
        "base_url": "https://openrouter.ai/api/v1",
        "last_status": "exhausted",
        "last_status_at": time.time() - 3700,  # ~1h2m ago
        "last_error_code": 402,
    }
    _write_auth_store(
        tmp_path,
        {
            "version": 1,
            "credential_pool": {"openrouter": [exhausted_credential]},
        },
    )

    from agent.credential_pool import load_pool

    credential_pool = load_pool("openrouter")
    selected = credential_pool.select()

    assert selected is not None
    assert selected.id == "cred-1"
    assert selected.last_status == "ok"
|
||||
|
||||
|
||||
def test_explicit_reset_timestamp_overrides_default_429_ttl(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
|
|
|
|||
750
tests/agent/test_error_classifier.py
Normal file
750
tests/agent/test_error_classifier.py
Normal file
|
|
@ -0,0 +1,750 @@
|
|||
"""Tests for agent.error_classifier — structured API error classification."""
|
||||
|
||||
import pytest
|
||||
from agent.error_classifier import (
|
||||
ClassifiedError,
|
||||
FailoverReason,
|
||||
classify_api_error,
|
||||
_extract_status_code,
|
||||
_extract_error_body,
|
||||
_extract_error_code,
|
||||
_classify_402,
|
||||
)
|
||||
|
||||
|
||||
# ── Helper: mock API errors ────────────────────────────────────────────
|
||||
|
||||
class MockAPIError(Exception):
    """Stand-in for an OpenAI SDK ``APIStatusError``.

    Carries the two attributes these tests set on it: ``status_code`` and
    ``body``. A missing or falsy ``body`` is stored as an empty dict.
    """

    def __init__(self, message, status_code=None, body=None):
        super().__init__(message)
        self.body = body if body else {}
        self.status_code = status_code
|
||||
|
||||
|
||||
class MockTransportError(Exception):
    """Simulates a transport-level error with a specific type name."""


class ReadTimeout(MockTransportError):
    """Transport error whose type name is ``ReadTimeout``."""


class ConnectError(MockTransportError):
    """Transport error whose type name is ``ConnectError``."""


class RemoteProtocolError(MockTransportError):
    """Transport error whose type name is ``RemoteProtocolError``."""


class ServerDisconnectedError(MockTransportError):
    """Transport error whose type name is ``ServerDisconnectedError``."""
|
||||
|
||||
|
||||
# ── Test: FailoverReason enum ──────────────────────────────────────────
|
||||
|
||||
class TestFailoverReason:
    """Sanity checks on the ``FailoverReason`` enum itself."""

    def test_all_reasons_have_string_values(self):
        """Every enum member carries a ``str`` value."""
        assert all(isinstance(member.value, str) for member in FailoverReason)

    def test_enum_members_exist(self):
        """The set of member values matches the expected set exactly."""
        expected = {
            # authentication / billing
            "auth", "auth_permanent", "billing", "rate_limit",
            # provider-side availability
            "overloaded", "server_error", "timeout",
            # request size
            "context_overflow", "payload_too_large",
            # request validity
            "model_not_found", "format_error",
            # special cases and catch-all
            "thinking_signature", "long_context_tier", "unknown",
        }
        actual = set()
        for member in FailoverReason:
            actual.add(member.value)
        assert expected == actual
|
||||
|
||||
|
||||
# ── Test: ClassifiedError ──────────────────────────────────────────────
|
||||
|
||||
class TestClassifiedError:
    """Derived properties and default field values of ``ClassifiedError``."""

    def test_is_auth_property(self):
        """``is_auth`` is true for both auth reasons and false for billing."""
        expectations = (
            (FailoverReason.auth, True),
            (FailoverReason.auth_permanent, True),
            (FailoverReason.billing, False),
        )
        for reason, expected in expectations:
            assert ClassifiedError(reason=reason).is_auth is expected

    def test_is_transient_property(self):
        """``is_transient`` separates retry-worthy reasons from permanent ones."""
        transient = (
            FailoverReason.rate_limit,
            FailoverReason.overloaded,
            FailoverReason.server_error,
            FailoverReason.timeout,
            FailoverReason.unknown,
        )
        for reason in transient:
            classified = ClassifiedError(reason=reason)
            assert classified.is_transient is True, f"{reason} should be transient"

        permanent = (
            FailoverReason.auth,
            FailoverReason.billing,
            FailoverReason.model_not_found,
            FailoverReason.format_error,
        )
        for reason in permanent:
            classified = ClassifiedError(reason=reason)
            assert classified.is_transient is False, f"{reason} should NOT be transient"

    def test_defaults(self):
        """Fields not passed to the constructor take their documented defaults."""
        classified = ClassifiedError(reason=FailoverReason.unknown)
        assert classified.retryable is True
        assert classified.should_compress is False
        assert classified.should_rotate_credential is False
        assert classified.should_fallback is False
        assert classified.status_code is None
        assert classified.message == ""
|
||||
|
||||
|
||||
# ── Test: Status code extraction ───────────────────────────────────────
|
||||
|
||||
class TestExtractStatusCode:
    """``_extract_status_code`` across the places a status can be found."""

    def test_from_status_code_attr(self):
        """A ``status_code`` attribute is returned as-is."""
        err = MockAPIError("fail", status_code=429)
        assert _extract_status_code(err) == 429

    def test_from_status_attr(self):
        """A ``status`` attribute is accepted as well."""
        class ErrWithStatus(Exception):
            status = 503

        err = ErrWithStatus()
        assert _extract_status_code(err) == 503

    def test_from_cause_chain(self):
        """The status is picked up from ``__cause__`` when the outer error has none."""
        wrapper = Exception("outer")
        wrapper.__cause__ = MockAPIError("inner", status_code=401)
        assert _extract_status_code(wrapper) == 401

    def test_none_when_missing(self):
        """A plain exception yields ``None``."""
        plain = Exception("generic")
        assert _extract_status_code(plain) is None

    def test_rejects_non_http_status(self):
        """Integers outside 100-599 on .status should be ignored."""
        class ErrWeirdStatus(Exception):
            status = 42

        err = ErrWeirdStatus()
        assert _extract_status_code(err) is None
|
||||
|
||||
|
||||
# ── Test: Error body extraction ────────────────────────────────────────
|
||||
|
||||
class TestExtractErrorBody:
    """``_extract_error_body`` with and without a ``body`` attribute."""

    def test_from_body_attr(self):
        """The ``body`` attribute of the error is returned."""
        err = MockAPIError("fail", body={"error": {"message": "bad"}})
        extracted = _extract_error_body(err)
        assert extracted == {"error": {"message": "bad"}}

    def test_empty_when_no_body(self):
        """An exception without a body yields an empty dict."""
        extracted = _extract_error_body(Exception("generic"))
        assert extracted == {}
|
||||
|
||||
|
||||
# ── Test: Error code extraction ────────────────────────────────────────
|
||||
|
||||
class TestExtractErrorCode:
    """``_extract_error_code`` over the body shapes exercised by these tests."""

    def test_from_nested_error_code(self):
        """``error.code`` is used when present."""
        code = _extract_error_code({"error": {"code": "rate_limit_exceeded"}})
        assert code == "rate_limit_exceeded"

    def test_from_nested_error_type(self):
        """``error.type`` is used when it is the only identifier given."""
        code = _extract_error_code({"error": {"type": "invalid_request_error"}})
        assert code == "invalid_request_error"

    def test_from_top_level_code(self):
        """A top-level ``code`` key is used when there is no ``error`` object."""
        code = _extract_error_code({"code": "model_not_found"})
        assert code == "model_not_found"

    def test_empty_when_no_code(self):
        """Bodies with no code or type yield an empty string."""
        assert _extract_error_code({}) == ""
        message_only = {"error": {"message": "oops"}}
        assert _extract_error_code(message_only) == ""
|
||||
|
||||
|
||||
# ── Test: 402 disambiguation ───────────────────────────────────────────
|
||||
|
||||
class TestClassify402:
    """The critical 402 billing vs rate_limit disambiguation."""

    @staticmethod
    def _make_result(reason, **kw):
        """Result factory handed to ``_classify_402`` as its second argument."""
        return ClassifiedError(reason=reason, **kw)

    def test_billing_exhaustion(self):
        """Plain 402 = billing."""
        outcome = _classify_402("payment required", self._make_result)
        assert outcome.reason == FailoverReason.billing
        assert outcome.should_rotate_credential is True

    def test_transient_usage_limit(self):
        """402 with 'usage limit' + 'try again' = rate limit, not billing."""
        outcome = _classify_402(
            "usage limit exceeded. try again in 5 minutes",
            self._make_result,
        )
        assert outcome.reason == FailoverReason.rate_limit
        assert outcome.should_rotate_credential is True

    def test_quota_with_retry(self):
        """402 with 'quota' + 'retry' = rate limit."""
        outcome = _classify_402(
            "quota exceeded, please retry after the window resets",
            self._make_result,
        )
        assert outcome.reason == FailoverReason.rate_limit

    def test_quota_without_retry(self):
        """402 with just 'quota' but no transient signal = billing."""
        outcome = _classify_402("quota exceeded", self._make_result)
        assert outcome.reason == FailoverReason.billing

    def test_insufficient_credits(self):
        """An 'insufficient credits' message is classified as billing."""
        outcome = _classify_402(
            "insufficient credits to complete request",
            self._make_result,
        )
        assert outcome.reason == FailoverReason.billing
|
||||
|
||||
|
||||
# ── Test: Full classification pipeline ─────────────────────────────────
|
||||
|
||||
class TestClassifyApiError:
|
||||
"""End-to-end classification tests."""
|
||||
|
||||
# ── Auth errors ──
|
||||
|
||||
def test_401_classified_as_auth(self):
|
||||
e = MockAPIError("Unauthorized", status_code=401)
|
||||
result = classify_api_error(e, provider="openrouter")
|
||||
assert result.reason == FailoverReason.auth
|
||||
assert result.should_rotate_credential is True
|
||||
# 401 is non-retryable on its own — credential rotation runs
|
||||
# before the retryability check in the agent loop.
|
||||
assert result.retryable is False
|
||||
assert result.should_fallback is True
|
||||
|
||||
def test_403_classified_as_auth(self):
|
||||
e = MockAPIError("Forbidden", status_code=403)
|
||||
result = classify_api_error(e, provider="anthropic")
|
||||
assert result.reason == FailoverReason.auth
|
||||
assert result.should_fallback is True
|
||||
|
||||
def test_403_key_limit_classified_as_billing(self):
|
||||
"""OpenRouter 403 'key limit exceeded' is billing, not auth."""
|
||||
e = MockAPIError("Key limit exceeded for this key", status_code=403)
|
||||
result = classify_api_error(e, provider="openrouter")
|
||||
assert result.reason == FailoverReason.billing
|
||||
assert result.should_rotate_credential is True
|
||||
assert result.should_fallback is True
|
||||
|
||||
def test_403_spending_limit_classified_as_billing(self):
|
||||
e = MockAPIError("spending limit reached", status_code=403)
|
||||
result = classify_api_error(e, provider="openrouter")
|
||||
assert result.reason == FailoverReason.billing
|
||||
|
||||
# ── Billing ──
|
||||
|
||||
def test_402_plain_billing(self):
|
||||
e = MockAPIError("Payment Required", status_code=402)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.billing
|
||||
assert result.retryable is False
|
||||
|
||||
def test_402_transient_usage_limit(self):
|
||||
e = MockAPIError("usage limit exceeded, try again later", status_code=402)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
assert result.retryable is True
|
||||
|
||||
# ── Rate limit ──
|
||||
|
||||
def test_429_rate_limit(self):
|
||||
e = MockAPIError("Too Many Requests", status_code=429)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
assert result.should_fallback is True
|
||||
|
||||
# ── Server errors ──
|
||||
|
||||
def test_500_server_error(self):
|
||||
e = MockAPIError("Internal Server Error", status_code=500)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.server_error
|
||||
assert result.retryable is True
|
||||
|
||||
def test_502_server_error(self):
|
||||
e = MockAPIError("Bad Gateway", status_code=502)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.server_error
|
||||
|
||||
def test_503_overloaded(self):
|
||||
e = MockAPIError("Service Unavailable", status_code=503)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.overloaded
|
||||
|
||||
def test_529_anthropic_overloaded(self):
|
||||
e = MockAPIError("Overloaded", status_code=529)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.overloaded
|
||||
|
||||
# ── Model not found ──
|
||||
|
||||
def test_404_model_not_found(self):
|
||||
e = MockAPIError("model not found", status_code=404)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.model_not_found
|
||||
assert result.should_fallback is True
|
||||
assert result.retryable is False
|
||||
|
||||
def test_404_generic(self):
|
||||
e = MockAPIError("Not Found", status_code=404)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.model_not_found
|
||||
|
||||
# ── Payload too large ──
|
||||
|
||||
def test_413_payload_too_large(self):
|
||||
e = MockAPIError("Request Entity Too Large", status_code=413)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.payload_too_large
|
||||
assert result.should_compress is True
|
||||
|
||||
# ── Context overflow ──
|
||||
|
||||
def test_400_context_length(self):
|
||||
e = MockAPIError("context length exceeded: 250000 > 200000", status_code=400)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
assert result.should_compress is True
|
||||
|
||||
def test_400_too_many_tokens(self):
|
||||
e = MockAPIError("This model's maximum context is 128000 tokens, too many tokens", status_code=400)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
def test_400_prompt_too_long(self):
|
||||
e = MockAPIError("prompt is too long: 300000 tokens > 200000 maximum", status_code=400)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
def test_400_generic_large_session(self):
|
||||
"""Generic 400 with large session → context overflow heuristic."""
|
||||
e = MockAPIError(
|
||||
"Error",
|
||||
status_code=400,
|
||||
body={"error": {"message": "Error"}},
|
||||
)
|
||||
result = classify_api_error(e, approx_tokens=100000, context_length=200000)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
def test_400_generic_small_session_is_format_error(self):
|
||||
"""Generic 400 with small session → format error, not context overflow."""
|
||||
e = MockAPIError(
|
||||
"Error",
|
||||
status_code=400,
|
||||
body={"error": {"message": "Error"}},
|
||||
)
|
||||
result = classify_api_error(e, approx_tokens=1000, context_length=200000)
|
||||
assert result.reason == FailoverReason.format_error
|
||||
|
||||
# ── Server disconnect + large session ──
|
||||
|
||||
def test_disconnect_large_session_context_overflow(self):
|
||||
"""Server disconnect with large session → context overflow."""
|
||||
e = Exception("server disconnected without sending complete message")
|
||||
result = classify_api_error(e, approx_tokens=150000, context_length=200000)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
assert result.should_compress is True
|
||||
|
||||
def test_disconnect_small_session_timeout(self):
|
||||
"""Server disconnect with small session → timeout."""
|
||||
e = Exception("server disconnected without sending complete message")
|
||||
result = classify_api_error(e, approx_tokens=5000, context_length=200000)
|
||||
assert result.reason == FailoverReason.timeout
|
||||
|
||||
# ── Provider-specific: Anthropic thinking signature ──
|
||||
|
||||
def test_anthropic_thinking_signature(self):
|
||||
e = MockAPIError(
|
||||
"thinking block has invalid signature",
|
||||
status_code=400,
|
||||
)
|
||||
result = classify_api_error(e, provider="anthropic")
|
||||
assert result.reason == FailoverReason.thinking_signature
|
||||
assert result.retryable is True
|
||||
|
||||
def test_non_anthropic_400_with_signature_not_classified_as_thinking(self):
    """A 400 mentioning 'signature' but not 'thinking' must not be thinking_signature.

    Only the negative is pinned: the test does not assert which other
    reason the classifier picks for this error.
    """
    e = MockAPIError("invalid signature", status_code=400)
    result = classify_api_error(e, provider="openrouter", approx_tokens=0)
    # Without "thinking" in the message, it shouldn't be thinking_signature
    assert result.reason != FailoverReason.thinking_signature
|
||||
|
||||
# ── Provider-specific: Anthropic long-context tier ──
|
||||
|
||||
def test_anthropic_long_context_tier(self):
|
||||
e = MockAPIError(
|
||||
"Extra usage is required for long context requests over 200k tokens",
|
||||
status_code=429,
|
||||
)
|
||||
result = classify_api_error(e, provider="anthropic", model="claude-sonnet-4")
|
||||
assert result.reason == FailoverReason.long_context_tier
|
||||
assert result.should_compress is True
|
||||
|
||||
def test_normal_429_not_long_context(self):
|
||||
"""Normal 429 without 'extra usage' + 'long context' → rate_limit."""
|
||||
e = MockAPIError("Too Many Requests", status_code=429)
|
||||
result = classify_api_error(e, provider="anthropic")
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
|
||||
# ── Transport errors ──
|
||||
|
||||
def test_read_timeout(self):
|
||||
e = ReadTimeout("Read timed out")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.timeout
|
||||
assert result.retryable is True
|
||||
|
||||
def test_connect_error(self):
|
||||
e = ConnectError("Connection refused")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.timeout
|
||||
|
||||
def test_connection_error_builtin(self):
|
||||
e = ConnectionError("Connection reset by peer")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.timeout
|
||||
|
||||
def test_timeout_error_builtin(self):
|
||||
e = TimeoutError("timed out")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.timeout
|
||||
|
||||
# ── Error code classification ──
|
||||
|
||||
def test_error_code_resource_exhausted(self):
|
||||
e = MockAPIError(
|
||||
"Resource exhausted",
|
||||
body={"error": {"code": "resource_exhausted", "message": "Too many requests"}},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
|
||||
def test_error_code_model_not_found(self):
|
||||
e = MockAPIError(
|
||||
"Model not available",
|
||||
body={"error": {"code": "model_not_found"}},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.model_not_found
|
||||
|
||||
def test_error_code_context_length_exceeded(self):
|
||||
e = MockAPIError(
|
||||
"Context too large",
|
||||
body={"error": {"code": "context_length_exceeded"}},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
# ── Message-only patterns (no status code) ──
|
||||
|
||||
def test_message_billing_pattern(self):
|
||||
e = Exception("insufficient credits to complete this request")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.billing
|
||||
|
||||
def test_message_rate_limit_pattern(self):
|
||||
e = Exception("rate limit reached for this model")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
|
||||
def test_message_auth_pattern(self):
|
||||
e = Exception("invalid api key provided")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.auth
|
||||
|
||||
def test_message_model_not_found_pattern(self):
|
||||
e = Exception("gpt-99 is not a valid model")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.model_not_found
|
||||
|
||||
def test_message_context_overflow_pattern(self):
|
||||
e = Exception("maximum context length exceeded")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
# ── Unknown / fallback ──
|
||||
|
||||
def test_generic_exception_is_unknown(self):
|
||||
e = Exception("something weird happened")
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.unknown
|
||||
assert result.retryable is True
|
||||
|
||||
# ── Format error ──
|
||||
|
||||
def test_400_descriptive_format_error(self):
|
||||
"""400 with descriptive message (not context overflow) → format error."""
|
||||
e = MockAPIError(
|
||||
"Invalid value for parameter 'temperature': must be between 0 and 2",
|
||||
status_code=400,
|
||||
body={"error": {"message": "Invalid value for parameter 'temperature': must be between 0 and 2"}},
|
||||
)
|
||||
result = classify_api_error(e, approx_tokens=1000)
|
||||
assert result.reason == FailoverReason.format_error
|
||||
assert result.retryable is False
|
||||
|
||||
def test_422_format_error(self):
|
||||
e = MockAPIError("Unprocessable Entity", status_code=422)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.format_error
|
||||
assert result.retryable is False
|
||||
|
||||
# ── Peer closed + large session ──
|
||||
|
||||
def test_peer_closed_large_session(self):
|
||||
e = Exception("peer closed connection without sending complete message")
|
||||
result = classify_api_error(e, approx_tokens=130000, context_length=200000)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
# ── Chinese error messages ──
|
||||
|
||||
def test_chinese_context_overflow(self):
|
||||
e = MockAPIError("超过最大长度限制", status_code=400)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
# ── Result metadata ──
|
||||
|
||||
def test_provider_and_model_in_result(self):
|
||||
e = MockAPIError("fail", status_code=500)
|
||||
result = classify_api_error(e, provider="openrouter", model="gpt-5")
|
||||
assert result.provider == "openrouter"
|
||||
assert result.model == "gpt-5"
|
||||
assert result.status_code == 500
|
||||
|
||||
def test_message_extracted(self):
|
||||
e = MockAPIError(
|
||||
"outer",
|
||||
status_code=500,
|
||||
body={"error": {"message": "Internal server error occurred"}},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
assert result.message == "Internal server error occurred"
|
||||
|
||||
|
||||
# ── Test: Adversarial / edge cases (from live testing) ─────────────────
|
||||
|
||||
class TestAdversarialEdgeCases:
|
||||
"""Edge cases discovered during live testing with real SDK objects."""
|
||||
|
||||
def test_empty_exception_message(self):
|
||||
result = classify_api_error(Exception(""))
|
||||
assert result.reason == FailoverReason.unknown
|
||||
assert result.retryable is True
|
||||
|
||||
def test_500_with_none_body(self):
|
||||
e = MockAPIError("fail", status_code=500, body=None)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.server_error
|
||||
|
||||
def test_non_dict_body(self):
|
||||
"""Some providers return strings instead of JSON."""
|
||||
class StringBodyError(Exception):
|
||||
status_code = 400
|
||||
body = "just a string"
|
||||
result = classify_api_error(StringBodyError("bad"))
|
||||
assert result.reason == FailoverReason.format_error
|
||||
|
||||
def test_list_body(self):
|
||||
class ListBodyError(Exception):
|
||||
status_code = 500
|
||||
body = [{"error": "something"}]
|
||||
result = classify_api_error(ListBodyError("server error"))
|
||||
assert result.reason == FailoverReason.server_error
|
||||
|
||||
def test_circular_cause_chain(self):
|
||||
"""Must not infinite-loop on circular __cause__."""
|
||||
e = Exception("circular")
|
||||
e.__cause__ = e
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.unknown
|
||||
|
||||
def test_three_level_cause_chain(self):
|
||||
inner = MockAPIError("inner", status_code=429)
|
||||
middle = Exception("middle")
|
||||
middle.__cause__ = inner
|
||||
outer = RuntimeError("outer")
|
||||
outer.__cause__ = middle
|
||||
result = classify_api_error(outer)
|
||||
assert result.status_code == 429
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
|
||||
def test_400_with_rate_limit_text(self):
|
||||
"""Some providers send rate limits as 400 instead of 429."""
|
||||
e = MockAPIError(
|
||||
"rate limit policy",
|
||||
status_code=400,
|
||||
body={"error": {"message": "rate limit exceeded on this model"}},
|
||||
)
|
||||
result = classify_api_error(e, provider="openrouter")
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
|
||||
def test_400_with_billing_text(self):
|
||||
"""Some providers send billing errors as 400."""
|
||||
e = MockAPIError(
|
||||
"billing",
|
||||
status_code=400,
|
||||
body={"error": {"message": "insufficient credits for this request"}},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.billing
|
||||
|
||||
def test_200_with_error_body(self):
|
||||
"""200 status with error in body — should be unknown, not crash."""
|
||||
class WeirdSuccess(Exception):
|
||||
status_code = 200
|
||||
body = {"error": {"message": "loading"}}
|
||||
result = classify_api_error(WeirdSuccess("model loading"))
|
||||
assert result.reason == FailoverReason.unknown
|
||||
|
||||
def test_ollama_context_size_exceeded(self):
|
||||
e = MockAPIError(
|
||||
"Error",
|
||||
status_code=400,
|
||||
body={"error": {"message": "context size has been exceeded"}},
|
||||
)
|
||||
result = classify_api_error(e, provider="ollama")
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
def test_connection_refused_error(self):
|
||||
e = ConnectionRefusedError("Connection refused: localhost:11434")
|
||||
result = classify_api_error(e, provider="ollama")
|
||||
assert result.reason == FailoverReason.timeout
|
||||
|
||||
def test_body_message_enrichment(self):
|
||||
"""Body message must be included in pattern matching even when
|
||||
str(error) doesn't contain it (OpenAI SDK APIStatusError)."""
|
||||
e = MockAPIError(
|
||||
"Usage limit", # str(e) = "usage limit"
|
||||
status_code=402,
|
||||
body={"error": {"message": "Usage limit reached, try again in 5 minutes"}},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
# "try again" is only in body, not in str(e)
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
|
||||
def test_disconnect_pattern_ordering(self):
    """Disconnect + large session must beat generic transport catch."""
    # A plain Exception is raised on purpose: per the original note its type
    # name isn't in _TRANSPORT_ERROR_TYPES, so only the message text carries
    # the disconnect pattern.
    e = Exception("peer closed connection without sending complete message")
    result = classify_api_error(e, approx_tokens=150000, context_length=200000)
    assert result.reason == FailoverReason.context_overflow
    assert result.should_compress is True
|
||||
|
||||
def test_credit_balance_too_low(self):
|
||||
e = MockAPIError(
|
||||
"Credits low",
|
||||
status_code=402,
|
||||
body={"error": {"message": "Your credit balance is too low"}},
|
||||
)
|
||||
result = classify_api_error(e, provider="anthropic")
|
||||
assert result.reason == FailoverReason.billing
|
||||
|
||||
def test_deepseek_402_chinese(self):
    """A 402 with a Chinese-only message is still classified as billing."""
    # "余额不足" doesn't match English billing patterns, but 402 defaults to billing
    e = MockAPIError("余额不足", status_code=402)
    result = classify_api_error(e, provider="deepseek")
    assert result.reason == FailoverReason.billing
|
||||
|
||||
def test_openrouter_wrapped_context_overflow_in_metadata_raw(self):
|
||||
"""OpenRouter wraps provider errors in metadata.raw JSON string."""
|
||||
e = MockAPIError(
|
||||
"Provider returned error",
|
||||
status_code=400,
|
||||
body={
|
||||
"error": {
|
||||
"message": "Provider returned error",
|
||||
"code": 400,
|
||||
"metadata": {
|
||||
"raw": '{"error":{"message":"context length exceeded: 50000 > 32768"}}'
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
result = classify_api_error(e, provider="openrouter", approx_tokens=10000)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
assert result.should_compress is True
|
||||
|
||||
def test_openrouter_wrapped_rate_limit_in_metadata_raw(self):
|
||||
e = MockAPIError(
|
||||
"Provider returned error",
|
||||
status_code=400,
|
||||
body={
|
||||
"error": {
|
||||
"message": "Provider returned error",
|
||||
"metadata": {
|
||||
"raw": '{"error":{"message":"Rate limit exceeded. Please retry after 30s."}}'
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
result = classify_api_error(e, provider="openrouter")
|
||||
assert result.reason == FailoverReason.rate_limit
|
||||
|
||||
def test_thinking_signature_via_openrouter(self):
|
||||
"""Thinking signature errors proxied through OpenRouter must be caught."""
|
||||
e = MockAPIError(
|
||||
"thinking block has invalid signature",
|
||||
status_code=400,
|
||||
)
|
||||
# provider is openrouter, not anthropic — old code missed this
|
||||
result = classify_api_error(e, provider="openrouter", model="anthropic/claude-sonnet-4")
|
||||
assert result.reason == FailoverReason.thinking_signature
|
||||
|
||||
def test_generic_400_large_by_message_count(self):
|
||||
"""Many small messages (>80) should trigger context overflow heuristic."""
|
||||
e = MockAPIError(
|
||||
"Error",
|
||||
status_code=400,
|
||||
body={"error": {"message": "Error"}},
|
||||
)
|
||||
# Low token count but high message count
|
||||
result = classify_api_error(
|
||||
e, approx_tokens=5000, context_length=200000, num_messages=100,
|
||||
)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
def test_disconnect_large_by_message_count(self):
|
||||
"""Server disconnect with 200+ messages should trigger context overflow."""
|
||||
e = Exception("server disconnected without sending complete message")
|
||||
result = classify_api_error(
|
||||
e, approx_tokens=5000, context_length=200000, num_messages=250,
|
||||
)
|
||||
assert result.reason == FailoverReason.context_overflow
|
||||
|
||||
def test_openrouter_wrapped_model_not_found_in_metadata_raw(self):
|
||||
e = MockAPIError(
|
||||
"Provider returned error",
|
||||
status_code=400,
|
||||
body={
|
||||
"error": {
|
||||
"message": "Provider returned error",
|
||||
"metadata": {
|
||||
"raw": '{"error":{"message":"The model gpt-99 does not exist"}}'
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
result = classify_api_error(e, provider="openrouter")
|
||||
assert result.reason == FailoverReason.model_not_found
|
||||
212
tests/agent/test_rate_limit_tracker.py
Normal file
212
tests/agent/test_rate_limit_tracker.py
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
"""Tests for agent.rate_limit_tracker — header parsing and formatting."""
|
||||
|
||||
import time
|
||||
import pytest
|
||||
from agent.rate_limit_tracker import (
|
||||
RateLimitBucket,
|
||||
RateLimitState,
|
||||
parse_rate_limit_headers,
|
||||
format_rate_limit_display,
|
||||
format_rate_limit_compact,
|
||||
_fmt_count,
|
||||
_fmt_seconds,
|
||||
_bar,
|
||||
)
|
||||
|
||||
|
||||
# ── Sample headers from Nous inference API ──────────────────────────────
|
||||
|
||||
# Sample rate-limit response headers used as the shared fixture for the
# tests below. All values are strings, as HTTP header values are.
NOUS_HEADERS = {
    "x-ratelimit-limit-requests": "800",
    "x-ratelimit-limit-requests-1h": "33600",
    "x-ratelimit-limit-tokens": "8000000",
    "x-ratelimit-limit-tokens-1h": "336000000",
    "x-ratelimit-remaining-requests": "795",
    "x-ratelimit-remaining-requests-1h": "33590",
    "x-ratelimit-remaining-tokens": "7999500",
    "x-ratelimit-remaining-tokens-1h": "335999000",
    "x-ratelimit-reset-requests": "45.5",
    "x-ratelimit-reset-requests-1h": "3500.0",
    "x-ratelimit-reset-tokens": "42.3",
    "x-ratelimit-reset-tokens-1h": "3490.0",
}
|
||||
|
||||
|
||||
class TestParseHeaders:
    """Tests for parse_rate_limit_headers on complete, partial and bad input."""

    def test_basic_parsing(self):
        """A full header set populates all four buckets and the provider."""
        state = parse_rate_limit_headers(NOUS_HEADERS, provider="nous")
        assert state is not None
        assert state.provider == "nous"
        assert state.has_data

        assert state.requests_min.limit == 800
        assert state.requests_min.remaining == 795
        assert state.requests_min.reset_seconds == 45.5

        assert state.requests_hour.limit == 33600
        assert state.requests_hour.remaining == 33590

        assert state.tokens_min.limit == 8000000
        assert state.tokens_min.remaining == 7999500

        assert state.tokens_hour.limit == 336000000
        assert state.tokens_hour.remaining == 335999000
        assert state.tokens_hour.reset_seconds == 3490.0

    def test_no_headers(self):
        """An empty mapping yields None."""
        state = parse_rate_limit_headers({})
        assert state is None

    def test_partial_headers(self):
        """Only some rate-limit headers present: the rest default to 0."""
        headers = {
            "x-ratelimit-limit-requests": "100",
            "x-ratelimit-remaining-requests": "50",
        }
        state = parse_rate_limit_headers(headers)
        assert state is not None
        assert state.requests_min.limit == 100
        assert state.requests_min.remaining == 50
        # Missing fields default to 0
        assert state.tokens_min.limit == 0

    def test_non_rate_limit_headers_ignored(self):
        """Headers without any rate-limit keys yield None."""
        headers = {
            "content-type": "application/json",
            "server": "nginx",
        }
        state = parse_rate_limit_headers(headers)
        assert state is None

    def test_malformed_values(self):
        """Unparseable values fall back to zero instead of raising."""
        headers = {
            "x-ratelimit-limit-requests": "not-a-number",
            "x-ratelimit-remaining-requests": "",
            "x-ratelimit-reset-requests": "abc",
        }
        state = parse_rate_limit_headers(headers)
        assert state is not None
        assert state.requests_min.limit == 0
        assert state.requests_min.remaining == 0
        assert state.requests_min.reset_seconds == 0.0
|
||||
|
||||
|
||||
class TestBucket:
    """Tests for the derived properties of RateLimitBucket."""

    def test_used(self):
        """used is limit minus remaining."""
        b = RateLimitBucket(limit=800, remaining=795, reset_seconds=45.0, captured_at=time.time())
        assert b.used == 5

    def test_usage_pct(self):
        """80 of 100 consumed reports 80%."""
        b = RateLimitBucket(limit=100, remaining=20, reset_seconds=30.0, captured_at=time.time())
        assert b.usage_pct == pytest.approx(80.0)

    def test_usage_pct_zero_limit(self):
        """A zero limit reports 0% rather than dividing by zero."""
        b = RateLimitBucket(limit=0, remaining=0)
        assert b.usage_pct == 0.0

    def test_remaining_seconds_now(self):
        """A 60s reset window captured 10s ago has about 50s left."""
        now = time.time()
        b = RateLimitBucket(limit=800, remaining=795, reset_seconds=60.0, captured_at=now - 10)
        # ~50 seconds should remain
        assert 49 <= b.remaining_seconds_now <= 51

    def test_remaining_seconds_expired(self):
        """A window that has already elapsed is clamped to 0.0."""
        b = RateLimitBucket(limit=800, remaining=795, reset_seconds=30.0, captured_at=time.time() - 60)
        assert b.remaining_seconds_now == 0.0
|
||||
|
||||
|
||||
class TestFormatting:
    """Tests for the count/duration/bar helpers and the two display formatters."""

    def test_fmt_count_millions(self):
        """Counts in the millions render with an 'M' suffix and one decimal."""
        assert _fmt_count(8000000) == "8.0M"
        assert _fmt_count(336000000) == "336.0M"

    def test_fmt_count_thousands(self):
        """Counts in the thousands render with a 'K' suffix and one decimal."""
        assert _fmt_count(33600) == "33.6K"
        assert _fmt_count(1500) == "1.5K"

    def test_fmt_count_small(self):
        """Counts below one thousand render as plain integers."""
        assert _fmt_count(800) == "800"
        assert _fmt_count(0) == "0"

    def test_fmt_seconds_short(self):
        """Durations under a minute render as seconds only."""
        assert _fmt_seconds(45) == "45s"
        assert _fmt_seconds(0) == "0s"

    def test_fmt_seconds_minutes(self):
        """Minute-scale durations omit a zero seconds component."""
        assert _fmt_seconds(125) == "2m 5s"
        assert _fmt_seconds(120) == "2m"

    def test_fmt_seconds_hours(self):
        """Hour-scale durations omit a zero minutes component."""
        assert _fmt_seconds(3660) == "1h 1m"
        assert _fmt_seconds(3600) == "1h"

    def test_bar(self):
        """The bar fills proportionally to the percentage at the given width."""
        bar = _bar(50.0, width=10)
        assert bar == "[█████░░░░░]"
        assert _bar(0.0, width=10) == "[░░░░░░░░░░]"
        assert _bar(100.0, width=10) == "[██████████]"

    def test_format_display_no_data(self):
        """A default-constructed state renders the 'no data' message."""
        state = RateLimitState()
        result = format_rate_limit_display(state)
        assert "No rate limit data" in result

    def test_format_display_with_data(self):
        """The full display names the provider and all four buckets."""
        state = parse_rate_limit_headers(NOUS_HEADERS, provider="nous")
        result = format_rate_limit_display(state)
        assert "Nous" in result
        assert "Requests/min" in result
        assert "Requests/hr" in result
        assert "Tokens/min" in result
        assert "Tokens/hr" in result
        assert "resets in" in result

    def test_format_display_warning_on_high_usage(self):
        """High usage on a bucket adds a warning marker to the display."""
        headers = {
            **NOUS_HEADERS,
            "x-ratelimit-remaining-requests": "50",  # 750/800 used = 93.75%
        }
        state = parse_rate_limit_headers(headers)
        result = format_rate_limit_display(state)
        assert "⚠" in result

    def test_format_compact(self):
        """The compact form labels all four buckets and mentions resets."""
        state = parse_rate_limit_headers(NOUS_HEADERS, provider="nous")
        result = format_rate_limit_compact(state)
        assert "RPM:" in result
        assert "RPH:" in result
        assert "TPM:" in result
        assert "TPH:" in result
        assert "resets" in result

    def test_format_compact_no_data(self):
        """A default-constructed state renders the 'no data' message."""
        state = RateLimitState()
        result = format_rate_limit_compact(state)
        assert "No rate limit data" in result
|
||||
|
||||
|
||||
class TestAgentIntegration:
    """Test that AIAgent captures rate limit state correctly."""

    def test_capture_rate_limits_from_headers(self):
        """Simulate the header capture path without a real API call."""

        # Stand-in for an httpx-like response; only ``headers`` is read.
        class MockResponse:
            headers = NOUS_HEADERS

        # Test the parsing directly
        state = parse_rate_limit_headers(MockResponse.headers, provider="nous")
        assert state is not None
        assert state.requests_min.limit == 800
        assert state.tokens_hour.limit == 336000000

    def test_capture_rate_limits_none_response(self):
        """Empty headers produce no state (None) instead of raising.

        Note: despite the test name, the input exercised is an empty
        mapping, not ``None``.
        """
        result = parse_rate_limit_headers({})
        assert result is None
|
||||
|
|
@ -3,6 +3,7 @@
|
|||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent.subdirectory_hints import SubdirectoryHintTracker
|
||||
|
||||
|
|
@ -189,3 +190,45 @@ class TestSubdirectoryHintTracker:
|
|||
"terminal", {"command": "curl https://example.com/frontend/api"}
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestPermissionErrorHandling:
    """Regression tests for PermissionError in filesystem checks (ref #6214)."""

    def test_is_valid_subdir_permission_error(self, tmp_path):
        """_is_valid_subdir should return False when is_dir() raises PermissionError."""
        tracker = SubdirectoryHintTracker(working_dir=str(tmp_path))
        restricted = tmp_path / "restricted"
        restricted.mkdir()
        with patch.object(Path, "is_dir", side_effect=PermissionError("Permission denied")):
            assert tracker._is_valid_subdir(restricted) is False

    def test_load_hints_permission_error_on_is_file(self, tmp_path):
        """_load_hints_for_directory should skip files when is_file() raises PermissionError."""
        tracker = SubdirectoryHintTracker(working_dir=str(tmp_path))
        restricted = tmp_path / "restricted"
        restricted.mkdir()
        original_is_file = Path.is_file

        def patched_is_file(self):
            # Fail only for paths under the restricted directory.
            if "restricted" in str(self):
                raise PermissionError("Permission denied")
            return original_is_file(self)

        with patch.object(Path, "is_file", patched_is_file):
            result = tracker._load_hints_for_directory(restricted)
        assert result is None

    def test_check_tool_call_survives_inaccessible_path(self, project):
        """Full check_tool_call should not crash when a path is inaccessible."""
        tracker = SubdirectoryHintTracker(working_dir=str(project))
        original_is_dir = Path.is_dir

        def patched_is_dir(self):
            # Fail for "backend" itself but not for anything beneath "src".
            if "backend" in str(self) and "src" not in str(self):
                raise PermissionError("Permission denied")
            return original_is_dir(self)

        with patch.object(Path, "is_dir", patched_is_dir):
            # Should not raise — gracefully skip the inaccessible directory
            result = tracker.check_tool_call(
                "read_file", {"path": str(project / "backend" / "src" / "main.py")}
            )
        # Result may be None (backend skipped) — the key point is no crash
        assert result is None or isinstance(result, str)
|
||||
|
|
|
|||
|
|
@ -2,22 +2,65 @@ import queue
|
|||
import threading
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import cli as cli_module
|
||||
from cli import HermesCLI
|
||||
|
||||
|
||||
class _FakeBuffer:
    """Minimal stand-in for an input buffer: holds text and a cursor position."""

    def __init__(self, text="", cursor_position=None):
        """Store *text*; the cursor defaults to the end of the text."""
        self.text = text
        self.cursor_position = len(text) if cursor_position is None else cursor_position

    def reset(self, append_to_history=False):
        """Clear the text and move the cursor to 0.

        ``append_to_history`` is accepted for call compatibility and ignored.
        """
        self.text = ""
        self.cursor_position = 0
|
||||
|
||||
|
||||
def _make_cli_stub():
    """Build a HermesCLI without running ``__init__`` for the UI tests.

    The instance is created with ``__new__`` and given only the attributes
    assigned below: approval/sudo state, a lock, a mocked ``_invalidate``,
    and a fake ``_app`` exposing ``invalidate`` and ``current_buffer``.
    """
    cli = HermesCLI.__new__(HermesCLI)
    cli._approval_state = None
    cli._approval_deadline = 0
    cli._approval_lock = threading.Lock()
    cli._sudo_state = None
    cli._sudo_deadline = 0
    cli._modal_input_snapshot = None
    cli._invalidate = MagicMock()
    # Single assignment: the earlier buffer-less SimpleNamespace was
    # immediately overwritten and has been dropped.
    cli._app = SimpleNamespace(invalidate=MagicMock(), current_buffer=_FakeBuffer())
    return cli
|
||||
|
||||
|
||||
class TestCliApprovalUi:
|
||||
def test_sudo_prompt_restores_existing_draft_after_response(self):
    """The sudo prompt empties the draft while active and restores it afterwards.

    The callback runs on a worker thread; the test waits for the prompt
    state to appear, feeds a password through the response queue, and then
    checks that the original draft text and cursor position are back.
    """
    cli = _make_cli_stub()
    cli._app.current_buffer = _FakeBuffer("draft command", cursor_position=5)
    result = {}

    def _run_callback():
        result["value"] = cli._sudo_password_callback()

    # Keep _cprint patched for the worker thread's whole lifetime.
    with patch.object(cli_module, "_cprint"):
        thread = threading.Thread(target=_run_callback, daemon=True)
        thread.start()

        # Poll (at most 2s) until the callback has published its prompt state.
        deadline = time.time() + 2
        while cli._sudo_state is None and time.time() < deadline:
            time.sleep(0.01)

        assert cli._sudo_state is not None
        assert cli._app.current_buffer.text == ""

        cli._app.current_buffer.text = "secret"
        cli._app.current_buffer.cursor_position = len("secret")
        cli._sudo_state["response_queue"].put("secret")

        thread.join(timeout=2)
        # join() returns silently on timeout; fail with a clear message
        # instead of a KeyError on result["value"] below.
        assert not thread.is_alive(), "sudo password callback did not return within 2s"

    assert result["value"] == "secret"
    assert cli._app.current_buffer.text == "draft command"
    assert cli._app.current_buffer.cursor_position == 5
|
||||
|
||||
def test_approval_callback_includes_view_for_long_commands(self):
|
||||
cli = _make_cli_stub()
|
||||
command = "sudo dd if=/tmp/githubcli-keyring.gpg of=/usr/share/keyrings/githubcli-archive-keyring.gpg bs=4M status=progress"
|
||||
|
|
|
|||
361
tests/gateway/test_bluebubbles.py
Normal file
361
tests/gateway/test_bluebubbles.py
Normal file
|
|
@ -0,0 +1,361 @@
|
|||
"""Tests for the BlueBubbles iMessage gateway adapter."""
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
|
||||
def _make_adapter(monkeypatch, **extra):
    """Construct a ``BlueBubblesAdapter`` with a local test configuration.

    The server URL and password are set both as environment variables and
    in the config's ``extra`` mapping; keyword arguments are merged into
    ``extra`` last, so they override the defaults.
    """
    url, password = "http://localhost:1234", "secret"
    monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", url)
    monkeypatch.setenv("BLUEBUBBLES_PASSWORD", password)
    # Imported only after the environment is prepared, as in every test here.
    from gateway.platforms.bluebubbles import BlueBubblesAdapter

    settings = {"server_url": url, "password": password}
    settings.update(extra)
    return BlueBubblesAdapter(PlatformConfig(enabled=True, extra=settings))
|
||||
|
||||
|
||||
class TestBlueBubblesPlatformEnum:
    """The ``Platform`` enum exposes a BlueBubbles member."""

    def test_bluebubbles_enum_exists(self):
        member = Platform.BLUEBUBBLES
        assert member.value == "bluebubbles"
|
||||
|
||||
|
||||
class TestBlueBubblesConfigLoading:
    """Environment variables are folded into the gateway configuration."""

    _URL = "http://localhost:1234"

    @staticmethod
    def _config_from_env():
        """Return a fresh ``GatewayConfig`` with env overrides applied."""
        from gateway.config import GatewayConfig, _apply_env_overrides

        config = GatewayConfig()
        _apply_env_overrides(config)
        return config

    def test_apply_env_overrides_bluebubbles(self, monkeypatch):
        monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", self._URL)
        monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret")
        monkeypatch.setenv("BLUEBUBBLES_WEBHOOK_PORT", "9999")

        config = self._config_from_env()

        assert Platform.BLUEBUBBLES in config.platforms
        platform_cfg = config.platforms[Platform.BLUEBUBBLES]
        assert platform_cfg.enabled is True
        assert platform_cfg.extra["server_url"] == "http://localhost:1234"
        assert platform_cfg.extra["password"] == "secret"
        assert platform_cfg.extra["webhook_port"] == 9999

    def test_connected_platforms_includes_bluebubbles(self, monkeypatch):
        monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", self._URL)
        monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret")

        connected = self._config_from_env().get_connected_platforms()

        assert Platform.BLUEBUBBLES in connected

    def test_home_channel_set_from_env(self, monkeypatch):
        monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", self._URL)
        monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret")
        monkeypatch.setenv("BLUEBUBBLES_HOME_CHANNEL", "user@example.com")

        config = self._config_from_env()

        home = config.platforms[Platform.BLUEBUBBLES].home_channel
        assert home is not None
        assert home.chat_id == "user@example.com"

    def test_not_connected_without_password(self, monkeypatch):
        monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", self._URL)
        monkeypatch.delenv("BLUEBUBBLES_PASSWORD", raising=False)

        connected = self._config_from_env().get_connected_platforms()

        assert Platform.BLUEBUBBLES not in connected
|
||||
|
||||
|
||||
class TestBlueBubblesHelpers:
    """Requirement check, markdown stripping and constructor normalisation."""

    def test_check_requirements(self, monkeypatch):
        monkeypatch.setenv("BLUEBUBBLES_SERVER_URL", "http://localhost:1234")
        monkeypatch.setenv("BLUEBUBBLES_PASSWORD", "secret")
        from gateway.platforms.bluebubbles import check_bluebubbles_requirements

        outcome = check_bluebubbles_requirements()
        assert outcome is True

    def test_format_message_strips_markdown(self, monkeypatch):
        formatted = _make_adapter(monkeypatch).format_message("**Hello** `world`")
        assert formatted == "Hello world"

    def test_strip_markdown_headers(self, monkeypatch):
        formatted = _make_adapter(monkeypatch).format_message("## Heading\ntext")
        assert formatted == "Heading\ntext"

    def test_strip_markdown_links(self, monkeypatch):
        formatted = _make_adapter(monkeypatch).format_message(
            "[click here](http://example.com)"
        )
        assert formatted == "click here"

    def test_init_normalizes_webhook_path(self, monkeypatch):
        hook = _make_adapter(monkeypatch, webhook_path="bluebubbles-webhook").webhook_path
        assert hook == "/bluebubbles-webhook"

    def test_init_preserves_leading_slash(self, monkeypatch):
        hook = _make_adapter(monkeypatch, webhook_path="/my-hook").webhook_path
        assert hook == "/my-hook"

    def test_server_url_normalized(self, monkeypatch):
        url = _make_adapter(monkeypatch, server_url="http://localhost:1234/").server_url
        assert url == "http://localhost:1234"

    def test_server_url_adds_scheme(self, monkeypatch):
        url = _make_adapter(monkeypatch, server_url="localhost:1234").server_url
        assert url == "http://localhost:1234"
|
||||
|
||||
|
||||
class TestBlueBubblesWebhookParsing:
    """Webhook payload record extraction and chat/sender field precedence."""

    @staticmethod
    def _chat_guid(adapter, record, payload):
        """Pick the chat GUID using the lookup order the tests exercise."""
        candidates = (
            record.get("chatGuid"),
            payload.get("chatGuid"),
            record.get("chat_guid"),
            payload.get("chat_guid"),
            payload.get("guid"),
        )
        return adapter._value(*candidates)

    def test_webhook_prefers_chat_guid_over_message_guid(self, monkeypatch):
        adapter = _make_adapter(monkeypatch)
        payload = {
            "guid": "MESSAGE-GUID",
            "chatGuid": "iMessage;-;user@example.com",
            "chatIdentifier": "user@example.com",
        }
        record = adapter._extract_payload_record(payload) or {}

        resolved = self._chat_guid(adapter, record, payload)

        assert resolved == "iMessage;-;user@example.com"

    def test_webhook_can_fall_back_to_sender_when_chat_fields_missing(self, monkeypatch):
        adapter = _make_adapter(monkeypatch)
        message = {
            "guid": "MESSAGE-GUID",
            "text": "hello",
            "handle": {"address": "user@example.com"},
            "isFromMe": False,
        }
        payload = {"data": message}
        record = adapter._extract_payload_record(payload) or {}

        chat_guid = self._chat_guid(adapter, record, payload)
        chat_identifier = adapter._value(
            record.get("chatIdentifier"),
            record.get("identifier"),
            payload.get("chatIdentifier"),
            payload.get("identifier"),
        )

        # The handle address is only consulted when "handle" is a dict.
        handle = record.get("handle")
        handle_address = handle.get("address") if isinstance(handle, dict) else None
        sender = adapter._value(
            handle_address,
            record.get("sender"),
            record.get("from"),
            record.get("address"),
        )
        sender = sender or chat_identifier or chat_guid

        if sender and not chat_guid and not chat_identifier:
            chat_identifier = sender
        assert chat_identifier == "user@example.com"

    def test_extract_payload_record_accepts_list_data(self, monkeypatch):
        adapter = _make_adapter(monkeypatch)
        first_entry = {
            "text": "hello",
            "chatGuid": "iMessage;-;user@example.com",
            "chatIdentifier": "user@example.com",
        }
        payload = {"type": "new-message", "data": [first_entry]}

        record = adapter._extract_payload_record(payload)

        assert record == payload["data"][0]

    def test_extract_payload_record_dict_data(self, monkeypatch):
        adapter = _make_adapter(monkeypatch)
        payload = {"data": {"text": "hello", "chatGuid": "iMessage;-;+1234"}}

        record = adapter._extract_payload_record(payload)

        assert record["text"] == "hello"

    def test_extract_payload_record_fallback_to_message(self, monkeypatch):
        adapter = _make_adapter(monkeypatch)
        payload = {"message": {"text": "hello"}}

        record = adapter._extract_payload_record(payload)

        assert record["text"] == "hello"
|
||||
|
||||
|
||||
class TestBlueBubblesGuidResolution:
|
||||
def test_raw_guid_returned_as_is(self, monkeypatch):
|
||||
"""If target already contains ';' it's a raw GUID — return unchanged."""
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
import asyncio
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter._resolve_chat_guid("iMessage;-;user@example.com")
|
||||
)
|
||||
assert result == "iMessage;-;user@example.com"
|
||||
|
||||
def test_empty_target_returns_none(self, monkeypatch):
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
import asyncio
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter._resolve_chat_guid("")
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestBlueBubblesToolsetIntegration:
    """The BlueBubbles toolset is registered and part of the gateway composite."""

    _NAME = "hermes-bluebubbles"

    def test_toolset_exists(self):
        from toolsets import TOOLSETS

        assert self._NAME in TOOLSETS

    def test_toolset_in_gateway_composite(self):
        from toolsets import TOOLSETS

        included = TOOLSETS["hermes-gateway"]["includes"]
        assert self._NAME in included
|
||||
|
||||
|
||||
class TestBlueBubblesPromptHint:
    """A platform hint for BlueBubbles exists and carries the expected wording."""

    def test_platform_hint_exists(self):
        from agent.prompt_builder import PLATFORM_HINTS

        assert "bluebubbles" in PLATFORM_HINTS
        hint = PLATFORM_HINTS["bluebubbles"]
        # Checked in the same order as before: product name, then format.
        for fragment in ("iMessage", "plain text"):
            assert fragment in hint
|
||||
|
||||
|
||||
class TestBlueBubblesAttachmentDownload:
|
||||
"""Verify _download_attachment routes to the correct cache helper."""
|
||||
|
||||
def test_download_image_uses_image_cache(self, monkeypatch):
|
||||
"""Image MIME routes to cache_image_from_bytes."""
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
import asyncio
|
||||
import httpx
|
||||
|
||||
# Mock the HTTP client response
|
||||
class MockResponse:
|
||||
status_code = 200
|
||||
content = b"\x89PNG\r\n\x1a\n"
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
async def mock_get(*args, **kwargs):
|
||||
return MockResponse()
|
||||
|
||||
adapter.client = type("MockClient", (), {"get": mock_get})()
|
||||
|
||||
cached_path = None
|
||||
|
||||
def mock_cache_image(data, ext):
|
||||
nonlocal cached_path
|
||||
cached_path = f"/tmp/test_image{ext}"
|
||||
return cached_path
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.bluebubbles.cache_image_from_bytes",
|
||||
mock_cache_image,
|
||||
)
|
||||
|
||||
att_meta = {"mimeType": "image/png", "transferName": "photo.png"}
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter._download_attachment("att-guid-123", att_meta)
|
||||
)
|
||||
assert result == "/tmp/test_image.png"
|
||||
|
||||
def test_download_audio_uses_audio_cache(self, monkeypatch):
|
||||
"""Audio MIME routes to cache_audio_from_bytes."""
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
import asyncio
|
||||
|
||||
class MockResponse:
|
||||
status_code = 200
|
||||
content = b"fake-audio-data"
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
async def mock_get(*args, **kwargs):
|
||||
return MockResponse()
|
||||
|
||||
adapter.client = type("MockClient", (), {"get": mock_get})()
|
||||
|
||||
cached_path = None
|
||||
|
||||
def mock_cache_audio(data, ext):
|
||||
nonlocal cached_path
|
||||
cached_path = f"/tmp/test_audio{ext}"
|
||||
return cached_path
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.bluebubbles.cache_audio_from_bytes",
|
||||
mock_cache_audio,
|
||||
)
|
||||
|
||||
att_meta = {"mimeType": "audio/mpeg", "transferName": "voice.mp3"}
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter._download_attachment("att-guid-456", att_meta)
|
||||
)
|
||||
assert result == "/tmp/test_audio.mp3"
|
||||
|
||||
def test_download_document_uses_document_cache(self, monkeypatch):
|
||||
"""Non-image/audio MIME routes to cache_document_from_bytes."""
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
import asyncio
|
||||
|
||||
class MockResponse:
|
||||
status_code = 200
|
||||
content = b"fake-doc-data"
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
async def mock_get(*args, **kwargs):
|
||||
return MockResponse()
|
||||
|
||||
adapter.client = type("MockClient", (), {"get": mock_get})()
|
||||
|
||||
cached_path = None
|
||||
|
||||
def mock_cache_doc(data, filename):
|
||||
nonlocal cached_path
|
||||
cached_path = f"/tmp/{filename}"
|
||||
return cached_path
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.bluebubbles.cache_document_from_bytes",
|
||||
mock_cache_doc,
|
||||
)
|
||||
|
||||
att_meta = {"mimeType": "application/pdf", "transferName": "report.pdf"}
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter._download_attachment("att-guid-789", att_meta)
|
||||
)
|
||||
assert result == "/tmp/report.pdf"
|
||||
|
||||
def test_download_returns_none_without_client(self, monkeypatch):
|
||||
"""No client → returns None gracefully."""
|
||||
adapter = _make_adapter(monkeypatch)
|
||||
adapter.client = None
|
||||
import asyncio
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter._download_attachment("att-guid", {"mimeType": "image/png"})
|
||||
)
|
||||
assert result is None
|
||||
|
|
@ -209,14 +209,31 @@ class TestIncomingDocumentHandling:
|
|||
assert "[Content of readme.md]:" in event.text
|
||||
assert "# Title" in event.text
|
||||
|
||||
@pytest.mark.asyncio
async def test_log_content_injected(self, adapter):
    """A small ``.log`` attachment sent as text/plain is injected into the event text."""
    log_bytes = b"BLE trace line 1\nBLE trace line 2"

    with _mock_aiohttp_download(log_bytes):
        log_attachment = make_attachment(
            filename="btsnoop_hci.log", content_type="text/plain"
        )
        msg = make_message(
            attachments=[log_attachment],
            content="please inspect this",
        )
        await adapter._handle_message(msg)

    event = adapter.handle_message.call_args.args[0]
    # Header, file body and the user's own message must all be present.
    for expected in (
        "[Content of btsnoop_hci.log]:",
        "BLE trace line 1",
        "please inspect this",
    ):
        assert expected in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oversized_document_skipped(self, adapter):
|
||||
"""A document over 20MB should be skipped — media_urls stays empty."""
|
||||
"""A document over 32MB should be skipped — media_urls stays empty."""
|
||||
msg = make_message([
|
||||
make_attachment(
|
||||
filename="huge.pdf",
|
||||
content_type="application/pdf",
|
||||
size=25 * 1024 * 1024,
|
||||
size=33 * 1024 * 1024,
|
||||
)
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
|
@ -226,6 +243,24 @@ class TestIncomingDocumentHandling:
|
|||
# handler must still be called
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
async def test_mid_sized_zip_under_32mb_is_cached(self, adapter):
    """A 25MB .zip should be accepted now that Discord documents allow up to 32MB."""
    twenty_five_mb = 25 * 1024 * 1024
    zip_attachment = make_attachment(
        filename="bugreport.zip",
        content_type="application/zip",
        size=twenty_five_mb,
    )
    msg = make_message([zip_attachment])

    with _mock_aiohttp_download(b"PK\x03\x04test"):
        await adapter._handle_message(msg)

    event = adapter.handle_message.call_args.args[0]
    assert len(event.media_urls) == 1
    assert event.media_types == ["application/zip"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zip_document_cached(self, adapter):
|
||||
"""A .zip file should be cached as a supported document."""
|
||||
|
|
|
|||
315
tests/gateway/test_gateway_inactivity_timeout.py
Normal file
315
tests/gateway/test_gateway_inactivity_timeout.py
Normal file
|
|
@ -0,0 +1,315 @@
|
|||
"""Tests for staged inactivity timeout in gateway agent runs.
|
||||
|
||||
Tests cover:
|
||||
- Warning fires once when inactivity reaches gateway_timeout_warning threshold
|
||||
- Warning does not fire when gateway_timeout is 0 (unlimited)
|
||||
- Warning fires only once per run, not on every poll
|
||||
- Full timeout still fires at gateway_timeout threshold
|
||||
- Warning respects HERMES_AGENT_TIMEOUT_WARNING env var
|
||||
- Warning disabled when gateway_timeout_warning is 0
|
||||
"""
|
||||
|
||||
import concurrent.futures
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
|
||||
class FakeAgent:
    """Agent double whose activity summary is fully controlled by the test."""

    def __init__(self, idle_seconds=0.0, activity_desc="tool_call",
                 current_tool=None, api_call_count=5, max_iterations=90):
        # Interrupt bookkeeping starts cleared; interrupt() fills it in.
        self._interrupted = False
        self._interrupt_msg = None
        self._idle_seconds = idle_seconds
        self._activity_desc = activity_desc
        self._current_tool = current_tool
        self._api_call_count = api_call_count
        self._max_iterations = max_iterations

    def get_activity_summary(self):
        """Return a summary dict built from the constructor arguments."""
        idle = self._idle_seconds
        return dict(
            last_activity_ts=time.time() - idle,
            last_activity_desc=self._activity_desc,
            seconds_since_activity=idle,
            current_tool=self._current_tool,
            api_call_count=self._api_call_count,
            max_iterations=self._max_iterations,
        )

    def interrupt(self, msg):
        """Record that an interrupt was requested and with which message."""
        self._interrupt_msg = msg
        self._interrupted = True

    def run_conversation(self, prompt):
        """Return a canned result immediately; *prompt* is ignored."""
        return {"final_response": "Done", "messages": []}
|
||||
|
||||
|
||||
class SlowFakeAgent(FakeAgent):
    """``FakeAgent`` whose run blocks for a while and that reports idleness over time."""

    def __init__(self, run_duration=0.5, idle_after=None, **kwargs):
        super().__init__(**kwargs)
        self._start_time = None  # set when run_conversation() begins
        self._run_duration = run_duration
        self._idle_after = idle_after

    def get_activity_summary(self):
        """Return the base summary, overriding idle time once the run has started.

        After ``idle_after`` seconds of run time the reported idle time is
        the excess over ``idle_after``; before that it is ``0.0``.
        """
        summary = super().get_activity_summary()
        if self._idle_after is None or not self._start_time:
            return summary

        elapsed = time.time() - self._start_time
        if elapsed <= self._idle_after:
            summary["seconds_since_activity"] = 0.0
            return summary

        summary["seconds_since_activity"] = elapsed - self._idle_after
        summary["last_activity_desc"] = "api_call_streaming"
        return summary

    def run_conversation(self, prompt):
        """Stamp the start time, sleep for ``run_duration`` and return a result."""
        self._start_time = time.time()
        time.sleep(self._run_duration)
        return {"final_response": "Completed after work", "messages": []}
|
||||
|
||||
|
||||
class TestStagedInactivityWarning:
|
||||
"""Test the staged inactivity warning before full timeout."""
|
||||
|
||||
def test_warning_fires_once_before_timeout(self):
|
||||
"""Warning fires when inactivity reaches warning threshold."""
|
||||
agent = SlowFakeAgent(
|
||||
run_duration=10.0,
|
||||
idle_after=0.1,
|
||||
activity_desc="api_call_streaming",
|
||||
)
|
||||
|
||||
_agent_timeout = 20.0
|
||||
_agent_warning = 5.0
|
||||
_POLL_INTERVAL = 0.1
|
||||
|
||||
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
future = pool.submit(agent.run_conversation, "test prompt")
|
||||
_inactivity_timeout = False
|
||||
_warning_fired = False
|
||||
_warning_send_count = 0
|
||||
|
||||
while True:
|
||||
done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL)
|
||||
if done:
|
||||
result = future.result()
|
||||
break
|
||||
_idle_secs = 0.0
|
||||
if hasattr(agent, "get_activity_summary"):
|
||||
try:
|
||||
_act = agent.get_activity_summary()
|
||||
_idle_secs = _act.get("seconds_since_activity", 0.0)
|
||||
except Exception:
|
||||
pass
|
||||
if (not _warning_fired and _agent_warning > 0
|
||||
and _idle_secs >= _agent_warning):
|
||||
_warning_fired = True
|
||||
_warning_send_count += 1
|
||||
if _idle_secs >= _agent_timeout:
|
||||
_inactivity_timeout = True
|
||||
break
|
||||
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
|
||||
assert _warning_fired
|
||||
assert _warning_send_count == 1
|
||||
assert not _inactivity_timeout
|
||||
|
||||
def test_warning_disabled_when_zero(self):
|
||||
"""No warning fires when gateway_timeout_warning is 0."""
|
||||
agent = SlowFakeAgent(
|
||||
run_duration=5.0,
|
||||
idle_after=0.1,
|
||||
)
|
||||
|
||||
_agent_timeout = 20.0
|
||||
_agent_warning = 0.0
|
||||
_POLL_INTERVAL = 0.1
|
||||
|
||||
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
future = pool.submit(agent.run_conversation, "test")
|
||||
_warning_fired = False
|
||||
|
||||
while True:
|
||||
done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL)
|
||||
if done:
|
||||
future.result()
|
||||
break
|
||||
_idle_secs = 0.0
|
||||
if hasattr(agent, "get_activity_summary"):
|
||||
try:
|
||||
_act = agent.get_activity_summary()
|
||||
_idle_secs = _act.get("seconds_since_activity", 0.0)
|
||||
except Exception:
|
||||
pass
|
||||
if (not _warning_fired and _agent_warning > 0
|
||||
and _idle_secs >= _agent_warning):
|
||||
_warning_fired = True
|
||||
if _idle_secs >= _agent_timeout:
|
||||
break
|
||||
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
assert not _warning_fired
|
||||
|
||||
def test_warning_fires_only_once(self):
|
||||
"""Warning fires exactly once even if agent remains idle."""
|
||||
agent = SlowFakeAgent(
|
||||
run_duration=10.0,
|
||||
idle_after=0.05,
|
||||
)
|
||||
|
||||
_agent_timeout = 20.0
|
||||
_agent_warning = 0.2
|
||||
_POLL_INTERVAL = 0.05
|
||||
|
||||
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
future = pool.submit(agent.run_conversation, "test")
|
||||
_warning_count = 0
|
||||
|
||||
while True:
|
||||
done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL)
|
||||
if done:
|
||||
future.result()
|
||||
break
|
||||
_idle_secs = 0.0
|
||||
if hasattr(agent, "get_activity_summary"):
|
||||
try:
|
||||
_act = agent.get_activity_summary()
|
||||
_idle_secs = _act.get("seconds_since_activity", 0.0)
|
||||
except Exception:
|
||||
pass
|
||||
if (not _warning_count and _agent_warning > 0
|
||||
and _idle_secs >= _agent_warning):
|
||||
_warning_count += 1
|
||||
if _idle_secs >= _agent_timeout:
|
||||
break
|
||||
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
assert _warning_count == 1
|
||||
|
||||
def test_full_timeout_still_fires_after_warning(self):
|
||||
"""Full timeout fires even after warning was sent."""
|
||||
agent = SlowFakeAgent(
|
||||
run_duration=15.0,
|
||||
idle_after=0.1,
|
||||
activity_desc="waiting for provider response (streaming)",
|
||||
)
|
||||
|
||||
_agent_timeout = 1.0
|
||||
_agent_warning = 0.3
|
||||
_POLL_INTERVAL = 0.05
|
||||
|
||||
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
future = pool.submit(agent.run_conversation, "test")
|
||||
_inactivity_timeout = False
|
||||
_warning_fired = False
|
||||
|
||||
while True:
|
||||
done, _ = concurrent.futures.wait({future}, timeout=_POLL_INTERVAL)
|
||||
if done:
|
||||
future.result()
|
||||
break
|
||||
_idle_secs = 0.0
|
||||
if hasattr(agent, "get_activity_summary"):
|
||||
try:
|
||||
_act = agent.get_activity_summary()
|
||||
_idle_secs = _act.get("seconds_since_activity", 0.0)
|
||||
except Exception:
|
||||
pass
|
||||
if (not _warning_fired and _agent_warning > 0
|
||||
and _idle_secs >= _agent_warning):
|
||||
_warning_fired = True
|
||||
if _idle_secs >= _agent_timeout:
|
||||
_inactivity_timeout = True
|
||||
break
|
||||
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
assert _warning_fired
|
||||
assert _inactivity_timeout
|
||||
|
||||
def test_warning_env_var_respected(self, monkeypatch):
|
||||
"""HERMES_AGENT_TIMEOUT_WARNING env var is parsed correctly."""
|
||||
monkeypatch.setenv("HERMES_AGENT_TIMEOUT_WARNING", "600")
|
||||
_warning = float(os.getenv("HERMES_AGENT_TIMEOUT_WARNING", 900))
|
||||
assert _warning == 600.0
|
||||
|
||||
def test_warning_zero_means_disabled(self, monkeypatch):
|
||||
"""HERMES_AGENT_TIMEOUT_WARNING=0 disables the warning."""
|
||||
monkeypatch.setenv("HERMES_AGENT_TIMEOUT_WARNING", "0")
|
||||
_raw = float(os.getenv("HERMES_AGENT_TIMEOUT_WARNING", 900))
|
||||
_warning = _raw if _raw > 0 else None
|
||||
assert _warning is None
|
||||
|
||||
def test_unlimited_timeout_no_warning(self):
|
||||
"""When timeout is unlimited (0), no warning fires either."""
|
||||
agent = SlowFakeAgent(
|
||||
run_duration=0.5,
|
||||
idle_after=0.0,
|
||||
)
|
||||
|
||||
_agent_timeout = None
|
||||
_agent_warning = 5.0
|
||||
_POLL_INTERVAL = 0.05
|
||||
|
||||
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
future = pool.submit(agent.run_conversation, "test")
|
||||
|
||||
result = future.result(timeout=2.0)
|
||||
pool.shutdown(wait=False)
|
||||
|
||||
assert result["final_response"] == "Completed after work"
|
||||
|
||||
|
||||
class TestWarningThresholdBelowTimeout:
    """A warning threshold below the timeout fires before the timeout does."""

    def test_warning_at_half_timeout(self):
        """Warning fires at half the timeout duration."""
        timeout_after, warn_after, poll_every = 2.0, 1.0, 0.05
        agent = SlowFakeAgent(
            run_duration=10.0,
            idle_after=0.1,
            activity_desc="receiving stream response",
        )

        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        pending = executor.submit(agent.run_conversation, "test")
        warned = False
        timed_out = False

        while not timed_out:
            finished, _ = concurrent.futures.wait({pending}, timeout=poll_every)
            if finished:
                pending.result()
                break

            idle = 0.0
            if hasattr(agent, "get_activity_summary"):
                try:
                    idle = agent.get_activity_summary().get("seconds_since_activity", 0.0)
                except Exception:
                    pass

            # Once set the flag stays set; the threshold only matters before that.
            warned = warned or (warn_after > 0 and idle >= warn_after)
            timed_out = idle >= timeout_after

        executor.shutdown(wait=False, cancel_futures=True)
        assert warned
        assert timed_out
|
||||
226
tests/gateway/test_internal_event_bypass_pairing.py
Normal file
226
tests/gateway/test_internal_event_bypass_pairing.py
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
"""Tests that internal synthetic events (e.g. background process completion)
|
||||
bypass user authorization and do not trigger DM pairing.
|
||||
|
||||
Regression test for the bug where ``_run_process_watcher`` with
|
||||
``notify_on_complete=True`` injected a ``MessageEvent`` without ``user_id``,
|
||||
causing ``_is_user_authorized`` to reject it and the gateway to send a
|
||||
pairing code to the chat.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _FakeRegistry:
|
||||
"""Return pre-canned sessions, then None once exhausted."""
|
||||
|
||||
def __init__(self, sessions):
|
||||
self._sessions = list(sessions)
|
||||
|
||||
def get(self, session_id):
|
||||
if self._sessions:
|
||||
return self._sessions.pop(0)
|
||||
return None
|
||||
|
||||
|
||||
def _build_runner(monkeypatch, tmp_path) -> GatewayRunner:
    """Return a ``GatewayRunner`` with a mocked Discord adapter.

    A ``config.yaml`` setting ``background_process_notifications: all`` is
    written into *tmp_path*, which is then patched in as
    ``gateway.run._hermes_home`` before the runner is constructed.
    """
    config_file = tmp_path / "config.yaml"
    config_file.write_text(
        "display:\n  background_process_notifications: all\n",
        encoding="utf-8",
    )

    import gateway.run as gateway_run

    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)

    runner = GatewayRunner(GatewayConfig())
    runner.adapters[Platform.DISCORD] = SimpleNamespace(
        send=AsyncMock(), handle_message=AsyncMock()
    )
    return runner
|
||||
|
||||
|
||||
def _watcher_dict_with_notify():
|
||||
return {
|
||||
"session_id": "proc_test_internal",
|
||||
"check_interval": 0,
|
||||
"session_key": "agent:main:discord:dm:123",
|
||||
"platform": "discord",
|
||||
"chat_id": "123",
|
||||
"thread_id": "",
|
||||
"notify_on_complete": True,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
async def test_notify_on_complete_sets_internal_flag(monkeypatch, tmp_path):
    """Synthetic completion event must have internal=True."""
    import tools.process_registry as pr_module

    finished_session = SimpleNamespace(
        output_buffer="done\n", exited=True, exit_code=0, command="echo test"
    )
    monkeypatch.setattr(
        pr_module, "process_registry", _FakeRegistry([finished_session])
    )

    # Make every asyncio.sleep return immediately so the watcher does not wait.
    async def _instant_sleep(*_a, **_kw):
        pass

    monkeypatch.setattr(asyncio, "sleep", _instant_sleep)

    runner = _build_runner(monkeypatch, tmp_path)
    discord_adapter = runner.adapters[Platform.DISCORD]

    await runner._run_process_watcher(_watcher_dict_with_notify())

    handler = discord_adapter.handle_message
    assert handler.await_count == 1
    event = handler.await_args.args[0]
    assert isinstance(event, MessageEvent)
    assert event.internal is True, "Synthetic completion event must be marked internal"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_internal_event_bypasses_authorization(monkeypatch, tmp_path):
    """An internal event should skip _is_user_authorized entirely."""
    import gateway.run as gateway_run

    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    (tmp_path / "config.yaml").write_text("", encoding="utf-8")

    runner = GatewayRunner(GatewayConfig())

    # Internal event without a user_id: the situation that used to be rejected.
    event = MessageEvent(
        text="[SYSTEM: Background process completed]",
        source=SessionSource(
            platform=Platform.DISCORD,
            chat_id="123",
            chat_type="dm",
        ),
        internal=True,
    )

    # Record every authorization call while still delegating to the real check.
    auth_calls = []
    real_auth = GatewayRunner._is_user_authorized

    def tracking_auth(self, src):
        auth_calls.append(src)
        return real_auth(self, src)

    monkeypatch.setattr(GatewayRunner, "_is_user_authorized", tracking_auth)

    # _handle_message is expected to get past the auth check and then fail on
    # downstream logic that has no setup here; only the auth skip matters.
    try:
        await runner._handle_message(event)
    except Exception:
        pass  # Expected — downstream code needs more setup

    assert not auth_calls, (
        "_is_user_authorized should NOT be called for internal events"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_internal_event_does_not_trigger_pairing(monkeypatch, tmp_path):
    """An internal event with no user_id must not generate a pairing code."""
    import gateway.run as gateway_run

    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    (tmp_path / "config.yaml").write_text("", encoding="utf-8")

    runner = GatewayRunner(GatewayConfig())
    # Give a pairing message somewhere to go, should one be produced.
    runner.adapters[Platform.DISCORD] = SimpleNamespace(send=AsyncMock())

    dm_source = SessionSource(
        platform=Platform.DISCORD,
        chat_id="123",
        chat_type="dm",  # DM would normally trigger pairing
    )
    event = MessageEvent(
        text="[SYSTEM: Background process completed]",
        source=dm_source,
        internal=True,
    )

    # Wrap generate_code so any call is recorded but still performed.
    generate_calls = []
    real_generate = runner.pairing_store.generate_code

    def tracking_generate(*args, **kwargs):
        generate_calls.append((args, kwargs))
        return real_generate(*args, **kwargs)

    runner.pairing_store.generate_code = tracking_generate

    try:
        await runner._handle_message(event)
    except Exception:
        pass  # Expected — downstream code needs more setup

    assert not generate_calls, (
        "Pairing code should NOT be generated for internal events"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_non_internal_event_without_user_triggers_pairing(monkeypatch, tmp_path):
    """Control case: a normal DM from an unknown user still gets the pairing prompt."""
    import gateway.run as gateway_run

    # Isolate the gateway in a temp home containing an empty config file.
    (tmp_path / "config.yaml").write_text("", encoding="utf-8")
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)

    # Drop allow-list env vars that could wave every user through (per the
    # original note, gateway/run.py loads the real ~/.hermes/.env at import).
    for allow_var in (
        "DISCORD_ALLOW_ALL_USERS",
        "DISCORD_ALLOWED_USERS",
        "GATEWAY_ALLOW_ALL_USERS",
        "GATEWAY_ALLOWED_USERS",
    ):
        monkeypatch.delenv(allow_var, raising=False)

    runner = GatewayRunner(GatewayConfig())
    discord_adapter = SimpleNamespace(send=AsyncMock())
    runner.adapters[Platform.DISCORD] = discord_adapter

    # Normal (non-internal) event from a user nobody has paired with.
    unknown_user_event = MessageEvent(
        text="hello",
        source=SessionSource(
            platform=Platform.DISCORD,
            chat_id="123",
            chat_type="dm",
            user_id="unknown_user_999",
        ),
        internal=False,
    )

    outcome = await runner._handle_message(unknown_user_event)

    # Unauthorized: nothing is returned and exactly one pairing message goes out.
    assert outcome is None
    assert discord_adapter.send.await_count == 1
    sent_text = discord_adapter.send.await_args.args[1]
    assert "don't recognize you" in sent_text
|
||||
|
|
@ -707,3 +707,66 @@ class TestSignalSendDocumentViaHelper:
|
|||
|
||||
assert result.success is False
|
||||
assert "/nonexistent.pdf" in result.error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send() returns message_id from timestamp (#4647)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalSendReturnsMessageId:
    """Signal send() must return a timestamp-based message_id so the stream
    consumer can follow its edit→fallback path correctly."""

    @staticmethod
    async def _send_hello(monkeypatch, rpc_result):
        """Send "hello" through an adapter whose RPC stub returns *rpc_result*."""
        signal_adapter = _make_signal_adapter(monkeypatch)
        signal_adapter._rpc, _ = _stub_rpc(rpc_result)
        signal_adapter._stop_typing_indicator = AsyncMock()
        return await signal_adapter.send(chat_id="+155****4567", content="hello")

    @pytest.mark.asyncio
    async def test_send_returns_timestamp_as_message_id(self, monkeypatch):
        """A timestamp in the RPC result comes back as a string message_id."""
        outcome = await self._send_hello(monkeypatch, {"timestamp": 1712345678000})

        assert outcome.success is True
        assert outcome.message_id == "1712345678000"

    @pytest.mark.asyncio
    async def test_send_returns_none_message_id_when_no_timestamp(self, monkeypatch):
        """A dict result without a timestamp key yields message_id None."""
        outcome = await self._send_hello(monkeypatch, {})  # No timestamp key

        assert outcome.success is True
        assert outcome.message_id is None

    @pytest.mark.asyncio
    async def test_send_returns_none_message_id_for_non_dict(self, monkeypatch):
        """A non-dict RPC result yields message_id None."""
        outcome = await self._send_hello(monkeypatch, "ok")  # Non-dict result

        assert outcome.success is True
        assert outcome.message_id is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# stop_typing() delegates to _stop_typing_indicator (#4647)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalStopTyping:
    """Signal must expose a public stop_typing() so base adapter's
    _keep_typing finally block can clean up platform-level typing tasks."""

    @pytest.mark.asyncio
    async def test_stop_typing_calls_private_method(self, monkeypatch):
        """stop_typing(chat) awaits _stop_typing_indicator(chat) exactly once."""
        signal_adapter = _make_signal_adapter(monkeypatch)
        signal_adapter._stop_typing_indicator = AsyncMock()
        recipient = "+155****4567"

        await signal_adapter.stop_typing(recipient)

        signal_adapter._stop_typing_indicator.assert_awaited_once_with(recipient)
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ class TestAppMentionHandler:
|
|||
"""Verify that the app_mention event handler is registered."""
|
||||
|
||||
def test_app_mention_registered_on_connect(self):
|
||||
"""connect() should register both 'message' and 'app_mention' handlers."""
|
||||
"""connect() should register message + assistant lifecycle handlers."""
|
||||
config = PlatformConfig(enabled=True, token="xoxb-fake")
|
||||
adapter = SlackAdapter(config)
|
||||
|
||||
|
|
@ -145,6 +145,8 @@ class TestAppMentionHandler:
|
|||
|
||||
assert "message" in registered_events
|
||||
assert "app_mention" in registered_events
|
||||
assert "assistant_thread_started" in registered_events
|
||||
assert "assistant_thread_context_changed" in registered_events
|
||||
assert "/hermes" in registered_commands
|
||||
|
||||
|
||||
|
|
@ -840,6 +842,114 @@ class TestThreadReplyHandling:
|
|||
adapter.handle_message.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestAssistantThreadLifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAssistantThreadLifecycle:
    """Slack Assistant lifecycle events should seed session/user context."""

    @pytest.fixture()
    def mock_session_store(self):
        """MagicMock session store exposing the attributes this class configures."""
        session_store = MagicMock()
        session_store._entries = {}
        session_store._ensure_loaded = MagicMock()
        session_store.config = MagicMock()
        session_store.config.group_sessions_per_user = True
        session_store.get_or_create_session = MagicMock()
        return session_store

    @pytest.fixture()
    def assistant_adapter(self, mock_session_store):
        """SlackAdapter with a mocked app/client, wired to the mock session store."""
        slack = SlackAdapter(PlatformConfig(enabled=True, token="***"))
        slack._app = MagicMock()
        slack._app.client = AsyncMock()
        slack._bot_user_id = "U_BOT"
        slack._team_bot_user_ids = {"T_TEAM": "U_BOT"}
        slack._running = True
        slack.handle_message = AsyncMock()
        slack.set_session_store(mock_session_store)
        return slack

    @pytest.mark.asyncio
    async def test_lifecycle_event_seeds_session_store(self, assistant_adapter, mock_session_store):
        """assistant_thread_started caches thread metadata and creates a session."""
        thread_info = {
            "channel_id": "D123",
            "thread_ts": "171.000",
            "user_id": "U_USER",
            "context": {"channel_id": "C_ORIGIN"},
        }
        lifecycle_event = {
            "type": "assistant_thread_started",
            "team_id": "T_TEAM",
            "assistant_thread": thread_info,
        }

        await assistant_adapter._handle_assistant_thread_lifecycle_event(lifecycle_event)

        cached = assistant_adapter._assistant_threads[("D123", "171.000")]
        assert cached["user_id"] == "U_USER"
        mock_session_store.get_or_create_session.assert_called_once()

        # The first positional argument is the session source built for the thread.
        seeded_source = mock_session_store.get_or_create_session.call_args[0][0]
        assert seeded_source.chat_id == "D123"
        assert seeded_source.chat_type == "dm"
        assert seeded_source.user_id == "U_USER"
        assert seeded_source.thread_id == "171.000"
        assert seeded_source.chat_topic == "C_ORIGIN"

    @pytest.mark.asyncio
    async def test_message_uses_cached_assistant_thread_identity(self, assistant_adapter):
        """A message lacking a user field picks up the cached thread's user_id."""
        assistant_adapter._assistant_threads[("D123", "171.000")] = {
            "channel_id": "D123",
            "thread_ts": "171.000",
            "user_id": "U_USER",
            "team_id": "T_TEAM",
        }
        slack_client = assistant_adapter._app.client
        slack_client.users_info = AsyncMock(return_value={
            "user": {"profile": {"display_name": "Tyler"}}
        })
        slack_client.reactions_add = AsyncMock()
        slack_client.reactions_remove = AsyncMock()

        # Note: the incoming payload deliberately carries no "user" key.
        incoming = {
            "text": "hello from assistant dm",
            "channel": "D123",
            "channel_type": "im",
            "thread_ts": "171.000",
            "ts": "171.111",
            "team": "T_TEAM",
        }

        await assistant_adapter._handle_slack_message(incoming)

        dispatched_source = assistant_adapter.handle_message.call_args[0][0].source
        assert dispatched_source.user_id == "U_USER"
        assert dispatched_source.thread_id == "171.000"
        assert dispatched_source.user_name == "Tyler"

    def test_assistant_threads_cache_eviction(self, assistant_adapter):
        """Cache should evict oldest entries when exceeding the size limit."""
        limit = 10
        assistant_adapter._ASSISTANT_THREADS_MAX = limit

        # Fill the cache exactly to its limit.
        for idx in range(limit):
            assistant_adapter._cache_assistant_thread_metadata({
                "channel_id": f"D{idx}",
                "thread_ts": f"{idx}.000",
                "user_id": f"U{idx}",
            })
        assert len(assistant_adapter._assistant_threads) == 10

        # One more entry should trigger eviction. The original note says the
        # cache trims to max // 2; this test only pins that it stays <= limit.
        assistant_adapter._cache_assistant_thread_metadata({
            "channel_id": "D999",
            "thread_ts": "999.000",
            "user_id": "U999",
        })
        assert len(assistant_adapter._assistant_threads) <= 10
        # The newest entry must survive eviction.
        assert ("D999", "999.000") in assistant_adapter._assistant_threads
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestUserNameResolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -383,6 +383,60 @@ class TestSegmentBreakOnToolBoundary:
|
|||
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
|
||||
assert sent_texts == ["Hello ▉", "Next segment"]
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_message_id_enters_fallback_mode(self):
    """A successful send with no message_id (Signal) must not re-send per delta.

    The consumer should switch to fallback mode and deliver only the
    continuation when the stream finishes.
    """
    platform = MagicMock()
    platform.MAX_MESSAGE_LENGTH = 4096
    platform.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
    platform.send = AsyncMock(side_effect=[
        # First send succeeds but returns no message_id (Signal behavior).
        SimpleNamespace(success=True, message_id=None),
        # Fallback final send succeeds.
        SimpleNamespace(success=True, message_id="msg_final"),
    ])

    consumer = GatewayStreamConsumer(
        platform,
        "chat_123",
        StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5),
    )

    consumer.on_delta("Hello")
    run_task = asyncio.create_task(consumer.run())
    await asyncio.sleep(0.08)
    consumer.on_delta(" world, this is a longer response.")
    await asyncio.sleep(0.08)
    consumer.finish()
    await run_task

    # Exactly two sends: the initial chunk plus the fallback continuation,
    # NOT one message per delta.
    assert platform.send.call_count == 2
    assert consumer.already_sent
    # With no valid message_id there is nothing to edit.
    platform.edit_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_message_id_single_delta_marks_already_sent(self):
    """A one-delta response with no message_id must still set already_sent.

    Otherwise the gateway would re-send the full response after streaming.
    """
    platform = MagicMock()
    platform.send = AsyncMock(
        return_value=SimpleNamespace(success=True, message_id=None)
    )
    platform.MAX_MESSAGE_LENGTH = 4096

    consumer = GatewayStreamConsumer(
        platform,
        "chat_123",
        StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5),
    )

    # Queue the whole response and the finish signal before running.
    consumer.on_delta("Short response.")
    consumer.finish()

    await consumer.run()

    assert consumer.already_sent
    # Only one send call (the initial message).
    assert platform.send.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_final_splits_long_continuation_without_dropping_text(self):
|
||||
"""Long continuation tails should be chunked when fallback final-send runs."""
|
||||
|
|
|
|||
70
tests/hermes_cli/test_model_switch_variant_tags.py
Normal file
70
tests/hermes_cli/test_model_switch_variant_tags.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
"""Tests for OpenRouter variant tag preservation in model switching.
|
||||
|
||||
Regression test for GitHub PR #6088 / Discord report: OpenRouter model IDs
|
||||
with variant suffixes like ``:free``, ``:extended``, ``:fast`` were being
|
||||
mangled by the colon-to-slash conversion in model_switch.py Step c.
|
||||
|
||||
The fix: Step c now skips colon→slash conversion when the model name already
|
||||
contains a forward slash (i.e. is already in ``vendor/model`` format), since
|
||||
the colon is a variant tag, not a vendor separator.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
from hermes_cli.model_switch import switch_model
|
||||
|
||||
|
||||
# Canned validation verdict shared by every mocked switch, so tests skip
# network calls, credential resolution, and catalog lookups.
_MOCK_VALIDATION = {"accepted": True, "persist": True, "recognized": True, "message": None}


def _run_switch(raw_input: str, current_provider: str = "openrouter") -> str:
    """Call ``switch_model`` with its collaborators patched out.

    Args:
        raw_input: Model string as the user would type it.
        current_provider: Provider the switch starts from.

    Returns:
        The ``new_model`` attribute of the switch result.

    Raises:
        AssertionError: If the switch result reports failure.
    """
    fake_runtime = {"api_key": "test", "base_url": "", "api_mode": "chat_completions"}

    with patch("hermes_cli.model_switch.resolve_alias", return_value=None), \
         patch("hermes_cli.model_switch.list_provider_models", return_value=[]), \
         patch("hermes_cli.runtime_provider.resolve_runtime_provider",
               return_value=fake_runtime), \
         patch("hermes_cli.models.validate_requested_model", return_value=_MOCK_VALIDATION), \
         patch("hermes_cli.model_switch.get_model_info", return_value=None), \
         patch("hermes_cli.model_switch.get_model_capabilities", return_value=None), \
         patch("hermes_cli.models.detect_provider_for_model", return_value=None):
        outcome = switch_model(
            raw_input=raw_input,
            current_provider=current_provider,
            current_model="anthropic/claude-sonnet-4.6",
        )
        assert outcome.success, f"switch_model failed: {outcome.error_message}"
        return outcome.new_model
|
||||
|
||||
|
||||
class TestVariantTagPreservation:
    """OpenRouter variant tags (:free, :extended, :fast) must survive model switching."""

    @pytest.mark.parametrize("model,expected", [
        ("nvidia/nemotron-3-super-120b-a12b:free", "nvidia/nemotron-3-super-120b-a12b:free"),
        ("anthropic/claude-sonnet-4.6:extended", "anthropic/claude-sonnet-4.6:extended"),
        ("meta-llama/llama-4-maverick:fast", "meta-llama/llama-4-maverick:fast"),
    ])
    def test_slash_format_preserves_variant_tag(self, model, expected):
        """Input already shaped vendor/model:tag keeps its tag untouched."""
        resolved = _run_switch(model)
        assert resolved == expected

    def test_legacy_colon_format_converts_to_slash(self):
        """Legacy vendor:model (no slash) is still rewritten to vendor/model."""
        resolved = _run_switch("nvidia:nemotron-3-super-120b-a12b")
        assert resolved == "nvidia/nemotron-3-super-120b-a12b"

    def test_legacy_colon_format_with_tag_converts_first_colon_only(self):
        """vendor:model:free (no slash) → vendor/model:free; only the first colon changes."""
        resolved = _run_switch("nvidia:nemotron-3-super-120b-a12b:free")
        assert resolved == "nvidia/nemotron-3-super-120b-a12b:free"

    def test_bare_model_name_unaffected(self):
        """A bare name with no colon or slash resolves to the vendor-qualified slug."""
        resolved = _run_switch("claude-sonnet-4.6")
        assert resolved == "anthropic/claude-sonnet-4.6"

    def test_already_correct_slug_no_tag(self):
        """A plain vendor/model slug without a tag passes through unchanged."""
        resolved = _run_switch("anthropic/claude-sonnet-4.6")
        assert resolved == "anthropic/claude-sonnet-4.6"
|
||||
598
tests/plugins/memory/test_hindsight_provider.py
Normal file
598
tests/plugins/memory/test_hindsight_provider.py
Normal file
|
|
@ -0,0 +1,598 @@
|
|||
"""Tests for the Hindsight memory provider plugin.
|
||||
|
||||
Tests cover config loading, tool handlers (tags, max_tokens, types),
|
||||
prefetch (auto_recall, preamble, query truncation), sync_turn (auto_retain,
|
||||
turn counting, tags), and schema completeness.
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from plugins.memory.hindsight import (
|
||||
HindsightMemoryProvider,
|
||||
RECALL_SCHEMA,
|
||||
REFLECT_SCHEMA,
|
||||
RETAIN_SCHEMA,
|
||||
_load_config,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _clean_env(monkeypatch):
    """Remove every HINDSIGHT_* variable listed below before each test runs."""
    hindsight_vars = (
        "HINDSIGHT_API_KEY",
        "HINDSIGHT_API_URL",
        "HINDSIGHT_BANK_ID",
        "HINDSIGHT_BUDGET",
        "HINDSIGHT_MODE",
        "HINDSIGHT_LLM_API_KEY",
    )
    for var_name in hindsight_vars:
        # raising=False: an already-unset variable is not an error.
        monkeypatch.delenv(var_name, raising=False)
|
||||
|
||||
|
||||
def _make_mock_client():
    """Build a MagicMock Hindsight client whose async methods are AsyncMocks.

    ``arecall`` resolves to an object with two canned results and ``areflect``
    to an object with canned text; the remaining async methods return the
    AsyncMock default.
    """
    canned_recall = SimpleNamespace(
        results=[
            SimpleNamespace(text="Memory 1"),
            SimpleNamespace(text="Memory 2"),
        ]
    )
    canned_reflection = SimpleNamespace(text="Synthesized answer")

    mock_client = MagicMock()
    mock_client.aretain = AsyncMock()
    mock_client.arecall = AsyncMock(return_value=canned_recall)
    mock_client.areflect = AsyncMock(return_value=canned_reflection)
    mock_client.aretain_batch = AsyncMock()
    mock_client.aclose = AsyncMock()
    return mock_client
|
||||
|
||||
|
||||
@pytest.fixture()
def provider(provider_with_config):
    """Return an initialized HindsightMemoryProvider using the default test config.

    Delegates to the ``provider_with_config`` factory with no overrides, so
    the default config dict, the ``get_hermes_home`` patch, and the mock
    client wiring are defined in exactly one place instead of being
    duplicated line-for-line in both fixtures.
    """
    return provider_with_config()
|
||||
|
||||
|
||||
@pytest.fixture()
def provider_with_config(tmp_path, monkeypatch):
    """Yield a factory that builds a provider from the default config plus overrides."""

    def _make(**overrides):
        # Defaults first; caller-supplied keys win on conflict.
        settings = {
            **{
                "mode": "cloud",
                "apiKey": "test-key",
                "api_url": "http://localhost:9999",
                "bank_id": "test-bank",
                "budget": "mid",
                "memory_mode": "hybrid",
            },
            **overrides,
        }

        # Write the config where the plugin looks for it under the fake home.
        config_file = tmp_path / "hindsight" / "config.json"
        config_file.parent.mkdir(parents=True, exist_ok=True)
        config_file.write_text(json.dumps(settings))

        monkeypatch.setattr(
            "plugins.memory.hindsight.get_hermes_home", lambda: tmp_path
        )

        built = HindsightMemoryProvider()
        built.initialize(session_id="test-session", hermes_home=str(tmp_path), platform="cli")
        # Swap in the mock client after initialization.
        built._client = _make_mock_client()
        return built

    return _make
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSchemas:
    """Shape checks for the three tool schemas and their exposure by the provider."""

    def test_retain_schema_has_content(self):
        """The retain schema is named correctly and requires ``content``."""
        params = RETAIN_SCHEMA["parameters"]
        assert RETAIN_SCHEMA["name"] == "hindsight_retain"
        assert "content" in params["properties"]
        assert "content" in params["required"]

    def test_recall_schema_has_query(self):
        """The recall schema is named correctly and requires ``query``."""
        params = RECALL_SCHEMA["parameters"]
        assert RECALL_SCHEMA["name"] == "hindsight_recall"
        assert "query" in params["properties"]
        assert "query" in params["required"]

    def test_reflect_schema_has_query(self):
        """The reflect schema is named correctly and declares ``query``."""
        assert REFLECT_SCHEMA["name"] == "hindsight_reflect"
        assert "query" in REFLECT_SCHEMA["parameters"]["properties"]

    def test_get_tool_schemas_returns_three(self, provider):
        """The default (hybrid) provider exposes exactly the three Hindsight tools."""
        schemas = provider.get_tool_schemas()
        assert len(schemas) == 3
        tool_names = {schema["name"] for schema in schemas}
        assert tool_names == {"hindsight_retain", "hindsight_recall", "hindsight_reflect"}

    def test_context_mode_returns_no_tools(self, provider_with_config):
        """With memory_mode="context" no tool schemas are exposed."""
        context_only = provider_with_config(memory_mode="context")
        assert context_only.get_tool_schemas() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConfig:
    """Config loading: built-in defaults, file overrides, and env-var fallback."""

    def test_default_values(self, provider):
        """A provider built from the minimal config exposes the expected defaults."""
        p = provider
        assert p._auto_retain is True
        assert p._auto_recall is True
        assert p._retain_every_n_turns == 1
        assert p._recall_max_tokens == 4096
        assert p._recall_max_input_chars == 800
        assert p._tags is None
        assert p._recall_tags is None
        assert p._bank_mission == ""
        assert p._bank_retain_mission is None
        assert p._retain_context == "conversation between Hermes Agent and the User"

    def test_custom_config_values(self, provider_with_config):
        """Every overridable key in the config file lands on its provider attribute."""
        custom = dict(
            tags=["tag1", "tag2"],
            recall_tags=["recall-tag"],
            recall_tags_match="all",
            auto_retain=False,
            auto_recall=False,
            retain_every_n_turns=3,
            retain_context="custom-ctx",
            bank_retain_mission="Extract key facts",
            recall_max_tokens=2048,
            recall_types=["world", "experience"],
            recall_prompt_preamble="Custom preamble:",
            recall_max_input_chars=500,
            bank_mission="Test agent mission",
        )
        p = provider_with_config(**custom)

        assert p._tags == ["tag1", "tag2"]
        assert p._recall_tags == ["recall-tag"]
        assert p._recall_tags_match == "all"
        assert p._auto_retain is False
        assert p._auto_recall is False
        assert p._retain_every_n_turns == 3
        assert p._retain_context == "custom-ctx"
        assert p._bank_retain_mission == "Extract key facts"
        assert p._recall_max_tokens == 2048
        assert p._recall_types == ["world", "experience"]
        assert p._recall_prompt_preamble == "Custom preamble:"
        assert p._recall_max_input_chars == 500
        assert p._bank_mission == "Test agent mission"

    def test_config_from_env_fallback(self, tmp_path, monkeypatch):
        """When no config file exists, falls back to env vars."""
        # Point the plugin at a home directory that does not exist.
        missing_home = tmp_path / "nonexistent"
        monkeypatch.setattr(
            "plugins.memory.hindsight.get_hermes_home",
            lambda: missing_home,
        )
        monkeypatch.setenv("HINDSIGHT_MODE", "cloud")
        monkeypatch.setenv("HINDSIGHT_API_KEY", "env-key")
        monkeypatch.setenv("HINDSIGHT_BANK_ID", "env-bank")
        monkeypatch.setenv("HINDSIGHT_BUDGET", "high")

        loaded = _load_config()

        assert loaded["apiKey"] == "env-key"
        hermes_bank = loaded["banks"]["hermes"]
        assert hermes_bank["bankId"] == "env-bank"
        assert hermes_bank["budget"] == "high"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool handler tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolHandlers:
    """``handle_tool_call`` dispatch, argument forwarding, and error reporting."""

    @staticmethod
    def _invoke(target, tool_name, arguments):
        """Run a tool call on *target* and decode its JSON string reply."""
        return json.loads(target.handle_tool_call(tool_name, arguments))

    def test_retain_success(self, provider):
        """retain stores the content in the configured bank and reports success."""
        reply = self._invoke(
            provider, "hindsight_retain", {"content": "user likes dark mode"}
        )
        assert reply["result"] == "Memory stored successfully."
        provider._client.aretain.assert_called_once()
        sent = provider._client.aretain.call_args.kwargs
        assert sent["bank_id"] == "test-bank"
        assert sent["content"] == "user likes dark mode"

    def test_retain_with_tags(self, provider_with_config):
        """Configured tags are forwarded to aretain."""
        tagged = provider_with_config(tags=["pref", "ui"])
        tagged.handle_tool_call("hindsight_retain", {"content": "likes dark mode"})
        sent = tagged._client.aretain.call_args.kwargs
        assert sent["tags"] == ["pref", "ui"]

    def test_retain_without_tags(self, provider):
        """With no tags configured, aretain receives no ``tags`` kwarg at all."""
        provider.handle_tool_call("hindsight_retain", {"content": "hello"})
        sent = provider._client.aretain.call_args.kwargs
        assert "tags" not in sent

    def test_retain_missing_content(self, provider):
        """retain without ``content`` yields an error payload."""
        reply = self._invoke(provider, "hindsight_retain", {})
        assert "error" in reply

    def test_recall_success(self, provider):
        """recall returns the text of every memory the client produced."""
        reply = self._invoke(provider, "hindsight_recall", {"query": "dark mode"})
        assert "Memory 1" in reply["result"]
        assert "Memory 2" in reply["result"]

    def test_recall_passes_max_tokens(self, provider_with_config):
        """The configured recall_max_tokens is forwarded to arecall."""
        capped = provider_with_config(recall_max_tokens=2048)
        capped.handle_tool_call("hindsight_recall", {"query": "test"})
        sent = capped._client.arecall.call_args.kwargs
        assert sent["max_tokens"] == 2048

    def test_recall_passes_tags(self, provider_with_config):
        """recall_tags and recall_tags_match are forwarded to arecall."""
        tagged = provider_with_config(recall_tags=["tag1"], recall_tags_match="all")
        tagged.handle_tool_call("hindsight_recall", {"query": "test"})
        sent = tagged._client.arecall.call_args.kwargs
        assert sent["tags"] == ["tag1"]
        assert sent["tags_match"] == "all"

    def test_recall_passes_types(self, provider_with_config):
        """recall_types is forwarded to arecall."""
        typed = provider_with_config(recall_types=["world", "experience"])
        typed.handle_tool_call("hindsight_recall", {"query": "test"})
        sent = typed._client.arecall.call_args.kwargs
        assert sent["types"] == ["world", "experience"]

    def test_recall_no_results(self, provider):
        """An empty result list produces the "no memories" message."""
        provider._client.arecall.return_value = SimpleNamespace(results=[])
        reply = self._invoke(provider, "hindsight_recall", {"query": "test"})
        assert reply["result"] == "No relevant memories found."

    def test_recall_missing_query(self, provider):
        """recall without ``query`` yields an error payload."""
        reply = self._invoke(provider, "hindsight_recall", {})
        assert "error" in reply

    def test_reflect_success(self, provider):
        """reflect returns the client's synthesized text."""
        reply = self._invoke(provider, "hindsight_reflect", {"query": "summarize"})
        assert reply["result"] == "Synthesized answer"

    def test_reflect_missing_query(self, provider):
        """reflect without ``query`` yields an error payload."""
        reply = self._invoke(provider, "hindsight_reflect", {})
        assert "error" in reply

    def test_unknown_tool(self, provider):
        """An unrecognized tool name yields an error payload."""
        reply = self._invoke(provider, "hindsight_unknown", {})
        assert "error" in reply

    def test_retain_error_handling(self, provider):
        """A client exception during retain surfaces as an error carrying its message."""
        provider._client.aretain.side_effect = RuntimeError("connection failed")
        reply = self._invoke(provider, "hindsight_retain", {"content": "test"})
        assert "error" in reply
        assert "connection failed" in reply["error"]

    def test_recall_error_handling(self, provider):
        """A client exception during recall surfaces as an error payload."""
        provider._client.arecall.side_effect = RuntimeError("timeout")
        reply = self._invoke(provider, "hindsight_recall", {"query": "test"})
        assert "error" in reply
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prefetch tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPrefetch:
    """prefetch() output formatting and queue_prefetch() gating and arguments."""

    def test_prefetch_returns_empty_when_no_result(self, provider):
        """With nothing prefetched, prefetch() returns the empty string."""
        assert provider.prefetch("test") == ""

    def test_prefetch_default_preamble(self, provider):
        """A stored prefetch result is returned under the default header."""
        provider._prefetch_result = "- some memory"
        result = provider.prefetch("test")
        assert "Hindsight Memory" in result
        assert "- some memory" in result

    def test_prefetch_custom_preamble(self, provider_with_config):
        """A configured recall_prompt_preamble replaces the default header."""
        p = provider_with_config(recall_prompt_preamble="Custom header:")
        p._prefetch_result = "- memory line"
        result = p.prefetch("test")
        assert result.startswith("Custom header:")
        assert "- memory line" in result

    def test_queue_prefetch_skipped_in_tools_mode(self, provider_with_config):
        """In memory_mode="tools" queue_prefetch must not start a thread."""
        p = provider_with_config(memory_mode="tools")
        p.queue_prefetch("test")
        # Should not start a thread
        assert p._prefetch_thread is None

    def test_queue_prefetch_skipped_when_auto_recall_off(self, provider_with_config):
        """With auto_recall disabled queue_prefetch must not start a thread."""
        p = provider_with_config(auto_recall=False)
        p.queue_prefetch("test")
        assert p._prefetch_thread is None

    def test_queue_prefetch_truncates_query(self, provider_with_config):
        """The query handed to arecall is cut to recall_max_input_chars."""
        p = provider_with_config(recall_max_input_chars=10)

        # Replace arecall with a recorder so the forwarded query can be inspected.
        captured = {}

        def _capture_recall(**kwargs):
            captured["query"] = kwargs.get("query", "")
            return SimpleNamespace(results=[])

        p._client.arecall = AsyncMock(side_effect=_capture_recall)

        long_query = "a" * 100
        p.queue_prefetch(long_query)
        if p._prefetch_thread:
            p._prefetch_thread.join(timeout=5.0)

        # Fail loudly if the recall never ran; otherwise the length check
        # below would pass vacuously without exercising truncation.
        assert "query" in captured, "arecall was never invoked by queue_prefetch"
        # The query passed to arecall should be truncated
        assert len(captured["query"]) <= 10

    def test_queue_prefetch_passes_recall_params(self, provider_with_config):
        """Recall tags, match mode, token cap, and types all reach arecall."""
        p = provider_with_config(
            recall_tags=["t1"],
            recall_tags_match="all",
            recall_max_tokens=1024,
            recall_types=["world"],
        )
        p.queue_prefetch("test query")
        if p._prefetch_thread:
            p._prefetch_thread.join(timeout=5.0)

        # Give a clear failure instead of an AttributeError on call_args=None.
        p._client.arecall.assert_called()
        call_kwargs = p._client.arecall.call_args.kwargs
        assert call_kwargs["max_tokens"] == 1024
        assert call_kwargs["tags"] == ["t1"]
        assert call_kwargs["tags_match"] == "all"
        assert call_kwargs["types"] == ["world"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# sync_turn tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSyncTurn:
    """Tests for ``sync_turn()``: what reaches ``aretain_batch`` and when."""

    @staticmethod
    def _wait_for_sync(provider):
        """Join the background sync thread if one was started."""
        worker = provider._sync_thread
        if worker:
            worker.join(timeout=5.0)

    def _get_retain_kwargs(self, provider):
        """Return the keyword arguments of the recorded aretain_batch call."""
        recorded_call = provider._client.aretain_batch.call_args
        return recorded_call.kwargs

    def _get_retain_content(self, provider):
        """Return the raw content string of the first retained item."""
        first_item = self._get_retain_kwargs(provider)["items"][0]
        return first_item["content"]

    def _get_retain_messages(self, provider):
        """Decode the retained content into messages.

        Content is a JSON array of turns: [[msgs...], [msgs...], ...]
        When exactly one turn is present its message list is returned;
        otherwise the full list of turns is returned.
        """
        turns = json.loads(self._get_retain_content(provider))
        if len(turns) == 1:
            return turns[0]
        return turns

    def test_sync_turn_retains(self, provider):
        provider.sync_turn("hello", "hi there")
        self._wait_for_sync(provider)
        provider._client.aretain_batch.assert_called_once()

        messages = self._get_retain_messages(provider)
        assert len(messages) == 2
        expected = (("user", "hello"), ("assistant", "hi there"))
        for message, (role, text) in zip(messages, expected):
            assert message["role"] == role
            assert message["content"] == text
            assert "timestamp" in message

    def test_sync_turn_skipped_when_auto_retain_off(self, provider_with_config):
        p = provider_with_config(auto_retain=False)
        p.sync_turn("hello", "hi")
        assert p._sync_thread is None
        p._client.aretain_batch.assert_not_called()

    def test_sync_turn_with_tags(self, provider_with_config):
        p = provider_with_config(tags=["conv", "session1"])
        p.sync_turn("hello", "hi")
        self._wait_for_sync(p)
        retained = p._client.aretain_batch.call_args.kwargs["items"]
        assert retained[0]["tags"] == ["conv", "session1"]

    def test_sync_turn_uses_aretain_batch(self, provider):
        """sync_turn should use aretain_batch with retain_async."""
        provider.sync_turn("hello", "hi")
        self._wait_for_sync(provider)

        batch_mock = provider._client.aretain_batch
        batch_mock.assert_called_once()
        sent = batch_mock.call_args.kwargs
        assert sent["document_id"] == "test-session"
        assert sent["retain_async"] is True
        assert len(sent["items"]) == 1
        assert sent["items"][0]["context"] == "conversation between Hermes Agent and the User"

    def test_sync_turn_custom_context(self, provider_with_config):
        p = provider_with_config(retain_context="my-agent")
        p.sync_turn("hello", "hi")
        self._wait_for_sync(p)
        retained = p._client.aretain_batch.call_args.kwargs["items"]
        assert retained[0]["context"] == "my-agent"

    def test_sync_turn_every_n_turns(self, provider_with_config):
        """With retain_every_n_turns=3, only retains on every 3rd turn."""
        p = provider_with_config(retain_every_n_turns=3)

        # The first two turns are buffered: no sync thread yet.
        for turn in (1, 2):
            p.sync_turn(f"turn{turn}-user", f"turn{turn}-asst")
            assert p._sync_thread is None

        # The third turn triggers the retain.
        p.sync_turn("turn3-user", "turn3-asst")
        assert p._sync_thread is not None
        p._sync_thread.join(timeout=5.0)

        p._client.aretain_batch.assert_called_once()
        content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"]
        # Should contain all 3 turns
        for turn in (1, 2, 3):
            assert f"turn{turn}-user" in content

    def test_sync_turn_accumulates_full_session(self, provider_with_config):
        """Each retain sends the ENTIRE session, not just the latest batch."""
        p = provider_with_config(retain_every_n_turns=2)

        p.sync_turn("turn1-user", "turn1-asst")
        p.sync_turn("turn2-user", "turn2-asst")
        self._wait_for_sync(p)

        # Forget the first retain so call_args below reflects the second one.
        p._client.aretain_batch.reset_mock()

        p.sync_turn("turn3-user", "turn3-asst")
        p.sync_turn("turn4-user", "turn4-asst")
        self._wait_for_sync(p)

        content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"]
        # Should contain ALL turns from the session
        for turn in (1, 2, 3, 4):
            assert f"turn{turn}-user" in content

    def test_sync_turn_passes_document_id(self, provider):
        """sync_turn should pass session_id as document_id for dedup."""
        provider.sync_turn("hello", "hi")
        self._wait_for_sync(provider)
        sent = provider._client.aretain_batch.call_args.kwargs
        assert sent["document_id"] == "test-session"

    def test_sync_turn_error_does_not_raise(self, provider):
        """Errors in sync_turn should be swallowed (non-blocking)."""
        provider._client.aretain_batch.side_effect = RuntimeError("network error")
        provider.sync_turn("hello", "hi")
        self._wait_for_sync(provider)
        # Reaching this point without an exception is the assertion.
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# System prompt tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSystemPrompt:
    """Checks the text returned by ``system_prompt_block()`` per memory mode."""

    def test_hybrid_mode_prompt(self, provider):
        prompt = provider.system_prompt_block()
        for fragment in ("Hindsight Memory", "hindsight_recall", "automatically injected"):
            assert fragment in prompt

    def test_context_mode_prompt(self, provider_with_config):
        prompt = provider_with_config(memory_mode="context").system_prompt_block()
        assert "context mode" in prompt
        # The recall tool is not mentioned in context mode.
        assert "hindsight_recall" not in prompt

    def test_tools_mode_prompt(self, provider_with_config):
        prompt = provider_with_config(memory_mode="tools").system_prompt_block()
        assert "tools mode" in prompt
        assert "hindsight_recall" in prompt
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config schema tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConfigSchema:
    """Checks that the provider's config schema exposes the expected keys."""

    def test_schema_has_all_new_fields(self, provider):
        """Every expected key must appear among the schema entries' ``key`` values."""
        expected_keys = {
            "mode", "api_url", "api_key", "llm_provider", "llm_api_key",
            "llm_model", "bank_id", "bank_mission", "bank_retain_mission",
            "recall_budget", "memory_mode", "recall_prefetch_method",
            "tags", "recall_tags", "recall_tags_match",
            "auto_recall", "auto_retain",
            "retain_every_n_turns", "retain_async",
            "retain_context",
            "recall_max_tokens", "recall_max_input_chars",
            "recall_prompt_preamble",
        }
        present = {entry["key"] for entry in provider.get_config_schema()}
        missing = expected_keys - present
        assert not missing, f"Missing: {missing}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Availability tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAvailability:
    """``is_available()`` under different environment configurations."""

    @staticmethod
    def _isolate(tmp_path, monkeypatch):
        """Make the test independent of the host's Hindsight configuration.

        Points the plugin's ``get_hermes_home`` at a directory that does not
        exist and removes the two environment variables the tests in this
        class use to enable availability, so a developer's ambient
        ``HINDSIGHT_API_KEY`` / ``HINDSIGHT_MODE`` cannot change the outcome.
        """
        monkeypatch.setattr(
            "plugins.memory.hindsight.get_hermes_home",
            lambda: tmp_path / "nonexistent",
        )
        for name in ("HINDSIGHT_API_KEY", "HINDSIGHT_MODE"):
            monkeypatch.delenv(name, raising=False)

    def test_available_with_api_key(self, tmp_path, monkeypatch):
        self._isolate(tmp_path, monkeypatch)
        monkeypatch.setenv("HINDSIGHT_API_KEY", "test-key")
        p = HindsightMemoryProvider()
        assert p.is_available()

    def test_not_available_without_config(self, tmp_path, monkeypatch):
        # No config file and no Hindsight environment variables.
        self._isolate(tmp_path, monkeypatch)
        p = HindsightMemoryProvider()
        assert not p.is_available()

    def test_available_in_local_mode(self, tmp_path, monkeypatch):
        self._isolate(tmp_path, monkeypatch)
        monkeypatch.setenv("HINDSIGHT_MODE", "local")
        p = HindsightMemoryProvider()
        assert p.is_available()
|
||||
|
|
@ -150,8 +150,8 @@ def agent():
|
|||
class TestContextPressureFlags:
|
||||
"""Context pressure warning flag tracking on AIAgent."""
|
||||
|
||||
def test_flag_initialized_false(self, agent):
|
||||
assert agent._context_pressure_warned is False
|
||||
def test_flag_initialized_zero(self, agent):
|
||||
assert agent._context_pressure_warned_at == 0.0
|
||||
|
||||
def test_emit_calls_status_callback(self, agent):
|
||||
"""status_callback should be invoked with event type and message."""
|
||||
|
|
@ -210,7 +210,7 @@ class TestContextPressureFlags:
|
|||
|
||||
def test_flag_reset_on_compression(self, agent):
|
||||
"""After _compress_context, context pressure flag should reset."""
|
||||
agent._context_pressure_warned = True
|
||||
agent._context_pressure_warned_at = 0.85
|
||||
agent.compression_enabled = True
|
||||
|
||||
agent.context_compressor = MagicMock()
|
||||
|
|
@ -219,6 +219,7 @@ class TestContextPressureFlags:
|
|||
]
|
||||
agent.context_compressor.context_length = 200_000
|
||||
agent.context_compressor.threshold_tokens = 100_000
|
||||
agent.context_compressor.compression_count = 1
|
||||
|
||||
agent._todo_store = MagicMock()
|
||||
agent._todo_store.format_for_injection.return_value = None
|
||||
|
|
@ -233,7 +234,7 @@ class TestContextPressureFlags:
|
|||
]
|
||||
agent._compress_context(messages, "system prompt")
|
||||
|
||||
assert agent._context_pressure_warned is False
|
||||
assert agent._context_pressure_warned_at == 0.0
|
||||
|
||||
def test_emit_callback_error_handled(self, agent):
|
||||
"""If status_callback raises, it should be caught gracefully."""
|
||||
|
|
@ -246,3 +247,115 @@ class TestContextPressureFlags:
|
|||
|
||||
# Should not raise
|
||||
agent._emit_context_pressure(0.85, compressor)
|
||||
|
||||
def test_tiered_reemits_at_95(self, agent):
    """Warning fires at 85%, then fires again when crossing 95%."""
    # NOTE(review): this only checks the tier comparison and a plain
    # attribute write; it does not call any emission method on the agent.
    lower_tier, upper_tier = 0.85, 0.95
    agent._context_pressure_warned_at = lower_tier
    # Simulate crossing 95%: the tier (0.95) > warned_at (0.85)
    assert upper_tier > agent._context_pressure_warned_at
    # After emission at 95%, the tier should update
    agent._context_pressure_warned_at = upper_tier
    assert agent._context_pressure_warned_at == upper_tier
|
||||
|
||||
def test_tiered_no_double_emit_at_same_level(self, agent):
    """Once warned at 85%, further 85%+ readings don't re-warn."""
    agent._context_pressure_warned_at = 0.85
    # At 88%, tier is 0.85, which is NOT > warned_at (0.85)
    reading = 0.88
    if reading >= 0.85:
        tier_for_reading = 0.85
    else:
        tier_for_reading = 0.0
    assert tier_for_reading <= agent._context_pressure_warned_at
|
||||
|
||||
def test_flag_not_reset_when_compression_insufficient(self, agent):
    """When compression can't drop below 85%, keep the flag set."""
    agent._context_pressure_warned_at = 0.85
    agent.compression_enabled = True

    # Use a small threshold so the tiny compressed output still
    # represents >= 85% of it (prevents flag reset).
    compressor = MagicMock(
        context_length=200,
        threshold_tokens=10,
        compression_count=1,
        last_prompt_tokens=0,
    )
    compressor.compress.return_value = [
        {"role": "user", "content": "Summary of conversation so far."}
    ]
    agent.context_compressor = compressor

    todo_store = MagicMock()
    todo_store.format_for_injection.return_value = None
    agent._todo_store = todo_store
    agent._build_system_prompt = MagicMock(return_value="system prompt")
    agent._cached_system_prompt = "old system prompt"
    agent._session_db = None

    history = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "hi there"},
    ]
    agent._compress_context(history, "system prompt")

    # Post-compression is ~90% of threshold — flag should NOT reset
    assert agent._context_pressure_warned_at == 0.85
|
||||
|
||||
|
||||
class TestContextPressureGatewayDedup:
    """Class-level dedup prevents warning spam across AIAgent instances.

    These tests manipulate ``AIAgent._context_pressure_last_warned`` (a
    class-level mapping of session id -> ``(tier, timestamp)``) directly and
    evaluate the dedup predicate against it.
    """

    def setup_method(self):
        """Remember the shared dedup dict and start each test from an empty one."""
        self._shared_store = AIAgent._context_pressure_last_warned
        self._shared_store.clear()

    def teardown_method(self):
        """Restore the original dict object and leave it empty.

        ``test_eviction_removes_stale_entries`` rebinds the class attribute to
        new dicts; without this, that rebinding and any leftover entries would
        leak into tests that run afterwards.
        """
        AIAgent._context_pressure_last_warned = self._shared_store
        self._shared_store.clear()

    @staticmethod
    def _should_warn(sid, tier):
        """Evaluate the dedup predicate for *sid* at *tier*.

        Warn when there is no entry, when the stored tier is lower than
        *tier*, or when the stored timestamp is at least
        ``AIAgent._CONTEXT_PRESSURE_COOLDOWN`` seconds old.
        NOTE(review): presumably mirrors the check inside AIAgent — verify
        against the implementation if that logic changes.
        """
        import time
        last = AIAgent._context_pressure_last_warned.get(sid)
        return (
            last is None
            or last[0] < tier
            or (time.time() - last[1]) >= AIAgent._CONTEXT_PRESSURE_COOLDOWN
        )

    def test_second_instance_within_cooldown_suppressed(self):
        """Same session, same tier, within cooldown — should be suppressed."""
        import time
        sid = "test_session_dedup"
        # Simulate first warning
        AIAgent._context_pressure_last_warned[sid] = (0.85, time.time())
        # Second instance checking same tier within cooldown
        assert not self._should_warn(sid, 0.85)

    def test_higher_tier_fires_despite_cooldown(self):
        """Same session, higher tier — should fire even within cooldown."""
        import time
        sid = "test_session_tier"
        AIAgent._context_pressure_last_warned[sid] = (0.85, time.time())
        # 0.95 > 0.85 stored tier → should warn
        assert self._should_warn(sid, 0.95)

    def test_warning_fires_after_cooldown_expires(self):
        """Same session, same tier, after cooldown — should fire again."""
        import time
        sid = "test_session_expired"
        # Set a timestamp far in the past
        expired_at = time.time() - AIAgent._CONTEXT_PRESSURE_COOLDOWN - 1
        AIAgent._context_pressure_last_warned[sid] = (0.85, expired_at)
        assert self._should_warn(sid, 0.85)

    def test_compression_clears_dedup(self):
        """After compression drops below 85%, dedup entry should be cleared."""
        import time
        sid = "test_session_clear"
        AIAgent._context_pressure_last_warned[sid] = (0.85, time.time())
        assert sid in AIAgent._context_pressure_last_warned
        # Simulate what _compress_context does on reset
        AIAgent._context_pressure_last_warned.pop(sid, None)
        assert sid not in AIAgent._context_pressure_last_warned

    def test_eviction_removes_stale_entries(self):
        """Stale entries older than 2x cooldown should be evicted."""
        import time
        _now = time.time()
        AIAgent._context_pressure_last_warned = {
            "fresh": (0.85, _now),
            "stale": (0.85, _now - AIAgent._CONTEXT_PRESSURE_COOLDOWN * 3),
        }
        _cutoff = _now - AIAgent._CONTEXT_PRESSURE_COOLDOWN * 2
        # Rebinding is undone by teardown_method.
        AIAgent._context_pressure_last_warned = {
            k: v for k, v in AIAgent._context_pressure_last_warned.items()
            if v[1] > _cutoff
        }
        assert "fresh" in AIAgent._context_pressure_last_warned
        assert "stale" not in AIAgent._context_pressure_last_warned
|
||||
|
|
|
|||
|
|
@ -91,6 +91,61 @@ def _chat_response_with_memory_call():
|
|||
)
|
||||
|
||||
|
||||
class TestFlushMemoriesRespectsConfigTimeout:
    """flush_memories() must not hardcode a timeout; the value is expected to
    come from configuration (auxiliary.flush_memories.timeout)."""

    def test_auxiliary_path_omits_explicit_timeout(self, monkeypatch):
        """On the auxiliary path no ``timeout`` kwarg is passed to call_llm,
        leaving _get_task_timeout('flush_memories') to resolve it."""
        agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter")
        mock_response = _chat_response_with_memory_call()
        conversation = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi"},
            {"role": "user", "content": "Note this"},
        ]

        with patch("agent.auxiliary_client.call_llm", return_value=mock_response) as mock_call, \
             patch("tools.memory_tool.memory_tool", return_value="Saved."):
            agent.flush_memories(conversation)

        mock_call.assert_called_once()
        passed_kwargs = mock_call.call_args.kwargs
        # timeout must NOT be explicitly passed (so _get_task_timeout resolves it)
        assert "timeout" not in passed_kwargs, (
            "flush_memories should not pass explicit timeout to _call_llm; "
            "let _get_task_timeout('flush_memories') resolve from config"
        )

    def test_fallback_path_uses_config_timeout(self, monkeypatch):
        """When call_llm fails and the direct client is used instead, the
        ``timeout`` kwarg equals what _get_task_timeout returned."""
        agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter")
        agent.client = MagicMock()
        create_mock = agent.client.chat.completions.create
        create_mock.return_value = _chat_response_with_memory_call()

        custom_timeout = 180.0
        conversation = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi"},
            {"role": "user", "content": "Save this"},
        ]

        with patch("agent.auxiliary_client.call_llm", side_effect=RuntimeError("no provider")), \
             patch("agent.auxiliary_client._get_task_timeout", return_value=custom_timeout) as mock_gtt, \
             patch("tools.memory_tool.memory_tool", return_value="Saved."):
            agent.flush_memories(conversation)

        mock_gtt.assert_called_once_with("flush_memories")
        create_mock.assert_called_once()
        actual_timeout = create_mock.call_args.kwargs.get("timeout")
        assert actual_timeout == custom_timeout, (
            f"Expected timeout={custom_timeout} from config, got {actual_timeout}"
        )
|
||||
|
||||
|
||||
class TestFlushMemoriesUsesAuxiliaryClient:
|
||||
"""When an auxiliary client is available, flush_memories should use it
|
||||
instead of self.client -- especially critical in Codex mode."""
|
||||
|
|
|
|||
|
|
@ -1668,12 +1668,15 @@ class TestRunConversation:
|
|||
if roles[i] == "assistant" and roles[i + 1] == "assistant":
|
||||
raise AssertionError("Consecutive assistant messages found in history")
|
||||
|
||||
def test_truly_empty_response_accepted_without_retry(self, agent):
|
||||
"""Truly empty response (no content, no reasoning) should still complete with (empty)."""
|
||||
def test_truly_empty_response_retries_3_times_then_empty(self, agent):
|
||||
"""Truly empty response (no content, no reasoning) retries 3 times then falls through to (empty)."""
|
||||
self._setup_agent(agent)
|
||||
agent.base_url = "http://127.0.0.1:1234/v1"
|
||||
empty_resp = _mock_response(content=None, finish_reason="stop")
|
||||
agent.client.chat.completions.create.side_effect = [empty_resp]
|
||||
# 4 responses: 1 original + 3 nudge retries, all empty
|
||||
agent.client.chat.completions.create.side_effect = [
|
||||
empty_resp, empty_resp, empty_resp, empty_resp,
|
||||
]
|
||||
with (
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
|
|
@ -1682,7 +1685,28 @@ class TestRunConversation:
|
|||
result = agent.run_conversation("answer me")
|
||||
assert result["completed"] is True
|
||||
assert result["final_response"] == "(empty)"
|
||||
assert result["api_calls"] == 1 # no retries
|
||||
assert result["api_calls"] == 4 # 1 original + 3 retries
|
||||
|
||||
def test_truly_empty_response_succeeds_on_nudge(self, agent):
    """Model produces content after being nudged for empty response."""
    self._setup_agent(agent)
    agent.base_url = "http://127.0.0.1:1234/v1"

    # 1 empty response, then model produces content on nudge
    scripted_responses = [
        _mock_response(content=None, finish_reason="stop"),
        _mock_response(
            content="Here is the actual answer.",
            finish_reason="stop",
        ),
    ]
    agent.client.chat.completions.create.side_effect = scripted_responses

    with (
        patch.object(agent, "_persist_session"),
        patch.object(agent, "_save_trajectory"),
        patch.object(agent, "_cleanup_task_resources"),
    ):
        outcome = agent.run_conversation("answer me")

    assert outcome["completed"] is True
    assert outcome["final_response"] == "Here is the actual answer."
    assert outcome["api_calls"] == 2  # 1 original + 1 nudge retry
|
||||
|
||||
def test_nous_401_refreshes_after_remint_and_retries(self, agent):
|
||||
self._setup_agent(agent)
|
||||
|
|
|
|||
|
|
@ -658,6 +658,47 @@ def test_workspace_agents_records_skip_when_missing(tmp_path: Path):
|
|||
assert wa_items[0]["status"] == "skipped"
|
||||
|
||||
|
||||
def test_cron_store_is_archived_without_config_cron_section(tmp_path: Path):
    """Bug fix: archive cron store even when openclaw.json has no top-level cron config."""
    mod = load_module()
    source = tmp_path / ".openclaw"
    target = tmp_path / ".hermes"
    output_dir = target / "migration-report"
    for directory in (source, target):
        directory.mkdir()

    # Config without a "cron" section, plus an on-disk cron store.
    source.joinpath("openclaw.json").write_text(json.dumps({"channels": {}}), encoding="utf-8")
    cron_dir = source / "cron"
    cron_dir.mkdir(parents=True)
    jobs_payload = {"version": 1, "jobs": [{"id": "job-1", "name": "demo"}]}
    cron_dir.joinpath("jobs.json").write_text(json.dumps(jobs_payload), encoding="utf-8")

    report = mod.Migrator(
        source_root=source,
        target_root=target,
        execute=True,
        workspace_target=None,
        overwrite=False,
        migrate_secrets=False,
        output_dir=output_dir,
        selected_options={"cron-jobs"},
    ).migrate()

    cron_items = [entry for entry in report["items"] if entry["kind"] == "cron-jobs"]
    archived_store = None
    for entry in cron_items:
        destination = entry["destination"]
        if destination and destination.endswith("archive/cron-store"):
            archived_store = entry
            break
    assert archived_store is not None
    assert Path(archived_store["destination"]).joinpath("jobs.json").exists()

    notes_text = output_dir.joinpath("MIGRATION_NOTES.md").read_text(encoding="utf-8")
    assert "Run `hermes cron` to recreate scheduled tasks" in notes_text
    assert "archive/cron-config.json" not in notes_text
|
||||
|
||||
|
||||
def test_skill_installs_cleanly_under_skills_guard():
|
||||
skills_guard = load_skills_guard()
|
||||
result = skills_guard.scan_skill(
|
||||
|
|
|
|||
|
|
@ -663,6 +663,84 @@ class TestPruneSessions:
|
|||
assert db.get_session("old_cli") is None
|
||||
assert db.get_session("old_tg") is not None
|
||||
|
||||
def test_prune_with_multilevel_chain(self, db):
    """Pruning old sessions orphans newer children instead of crashing on FK."""
    old_ts = time.time() - 200 * 86400
    recent_ts = time.time() - 10 * 86400

    # Chain: A (old) -> B (old) -> C (recent) -> D (recent)
    chain = (
        ("A", None, "compressed"),
        ("B", "A", "compressed"),
        ("C", "B", "compressed"),
        ("D", "C", "done"),
    )
    for session_id, parent_id, reason in chain:
        # The root is created without a parent_session_id argument at all.
        parent_kwargs = {} if parent_id is None else {"parent_session_id": parent_id}
        db.create_session(session_id=session_id, source="cli", **parent_kwargs)
        db.end_session(session_id, end_reason=reason)

    # Backdate A and B to be old; C and D stay recent
    started_at = {"A": old_ts, "B": old_ts, "C": recent_ts, "D": recent_ts}
    for session_id, ts in started_at.items():
        db._conn.execute(
            "UPDATE sessions SET started_at = ? WHERE id = ?", (ts, session_id)
        )
    db._conn.commit()

    # Should not raise IntegrityError
    pruned = db.prune_sessions(older_than_days=90)
    assert pruned == 2  # only A and B
    assert db.get_session("A") is None
    assert db.get_session("B") is None

    # C and D survive, C is orphaned (parent_session_id NULL)
    survivor_c = db.get_session("C")
    assert survivor_c is not None
    assert survivor_c["parent_session_id"] is None
    survivor_d = db.get_session("D")
    assert survivor_d is not None
    assert survivor_d["parent_session_id"] == "C"
|
||||
|
||||
def test_prune_entire_old_chain(self, db):
    """All sessions in a chain are old — entire chain is pruned."""
    old_ts = time.time() - 200 * 86400

    # Chain: X -> Y -> Z, every link backdated below.
    chain = (
        ("X", None, "compressed"),
        ("Y", "X", "compressed"),
        ("Z", "Y", "done"),
    )
    for session_id, parent_id, reason in chain:
        parent_kwargs = {} if parent_id is None else {"parent_session_id": parent_id}
        db.create_session(session_id=session_id, source="cli", **parent_kwargs)
        db.end_session(session_id, end_reason=reason)

    chain_ids = ("X", "Y", "Z")
    for session_id in chain_ids:
        db._conn.execute(
            "UPDATE sessions SET started_at = ? WHERE id = ?", (old_ts, session_id)
        )
    db._conn.commit()

    pruned = db.prune_sessions(older_than_days=90)
    assert pruned == 3
    for session_id in chain_ids:
        assert db.get_session(session_id) is None
|
||||
|
||||
|
||||
class TestDeleteSessionOrphansChildren:
    """``delete_session`` on a parent detaches, rather than deletes, its children."""

    def test_delete_orphans_children(self, db):
        """Deleting a parent session orphans its children."""
        lineage = (("parent", None), ("child", "parent"), ("grandchild", "child"))
        for session_id, parent_id in lineage:
            parent_kwargs = {} if parent_id is None else {"parent_session_id": parent_id}
            db.create_session(session_id=session_id, source="cli", **parent_kwargs)

        # Should not raise IntegrityError
        deleted = db.delete_session("parent")
        assert deleted is True
        assert db.get_session("parent") is None

        # Child is orphaned, not deleted
        child_row = db.get_session("child")
        assert child_row is not None
        assert child_row["parent_session_id"] is None

        # Grandchild is untouched
        grandchild_row = db.get_session("grandchild")
        assert grandchild_row is not None
        assert grandchild_row["parent_session_id"] == "child"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Schema and WAL mode
|
||||
|
|
|
|||
174
tests/tools/test_base_environment.py
Normal file
174
tests/tools/test_base_environment.py
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
"""Tests for BaseEnvironment unified execution model.
|
||||
|
||||
Tests _wrap_command(), _extract_cwd_from_output(), _embed_stdin_heredoc(),
|
||||
init_session() failure handling, and the CWD marker contract.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from tools.environments.base import BaseEnvironment, _cwd_marker
|
||||
|
||||
|
||||
class _TestableEnv(BaseEnvironment):
    """Minimal concrete ``BaseEnvironment`` for exercising base-class methods.

    ``_run_bash`` deliberately raises; tests that need it replace it on the
    instance with their own callable.
    """

    def __init__(self, cwd="/tmp", timeout=10):
        # Only the defaults differ here; construction is delegated to the base.
        super().__init__(cwd=cwd, timeout=timeout)

    def _run_bash(self, cmd_string, *, login=False, timeout=120, stdin_data=None):
        """Placeholder backend — always raises so accidental use is obvious."""
        raise NotImplementedError("Use mock")

    def cleanup(self):
        """Nothing to release for the test double."""
        return None
|
||||
|
||||
|
||||
class TestWrapCommand:
    """Checks the shell text produced by ``_wrap_command()``."""

    @staticmethod
    def _wrap(command, cwd, *, snapshot_ready=True):
        """Build a fresh env with the given snapshot state and wrap *command*."""
        env = _TestableEnv()
        env._snapshot_ready = snapshot_ready
        return env, env._wrap_command(command, cwd)

    def test_basic_shape(self):
        env, wrapped = self._wrap("echo hello", "/tmp")

        assert "source" in wrapped
        # The cwd may or may not be quoted.
        assert any(form in wrapped for form in ("cd /tmp", "cd '/tmp'"))
        for fragment in (
            "eval 'echo hello'",
            "__hermes_ec=$?",
            "export -p >",
            "pwd -P >",
        ):
            assert fragment in wrapped
        assert env._cwd_marker in wrapped
        assert "exit $__hermes_ec" in wrapped

    def test_no_snapshot_skips_source(self):
        _, wrapped = self._wrap("echo hello", "/tmp", snapshot_ready=False)

        assert "source" not in wrapped

    def test_single_quote_escaping(self):
        _, wrapped = self._wrap("echo 'hello world'", "/tmp")

        assert "eval 'echo '\\''hello world'\\'''" in wrapped

    def test_tilde_not_quoted(self):
        _, wrapped = self._wrap("ls", "~")

        assert "cd ~" in wrapped
        assert "cd '~'" not in wrapped

    def test_cd_failure_exit_126(self):
        _, wrapped = self._wrap("ls", "/nonexistent")

        assert "exit 126" in wrapped
|
||||
|
||||
|
||||
class TestExtractCwdFromOutput:
    """Checks how ``_extract_cwd_from_output()`` updates ``cwd`` and the output."""

    @staticmethod
    def _extract(env, text):
        """Run the extractor over a result dict holding *text*; return the dict."""
        result = {"output": text}
        env._extract_cwd_from_output(result)
        return result

    def test_happy_path(self):
        env = _TestableEnv()
        marker = env._cwd_marker
        result = self._extract(env, f"hello\n{marker}/home/user{marker}\n")

        assert env.cwd == "/home/user"
        assert marker not in result["output"]

    def test_missing_marker(self):
        env = _TestableEnv()
        self._extract(env, "hello world\n")

        assert env.cwd == "/tmp"  # unchanged

    def test_marker_in_command_output(self):
        """A stray marker earlier in the output must not win: the trailing
        marker pair is the one that determines the new cwd."""
        env = _TestableEnv()
        marker = env._cwd_marker
        self._extract(
            env,
            f"user typed {marker} in their output\nreal output\n{marker}/correct/path{marker}\n",
        )

        assert env.cwd == "/correct/path"

    def test_output_cleaned(self):
        env = _TestableEnv()
        marker = env._cwd_marker
        result = self._extract(env, f"hello\n{marker}/tmp{marker}\n")

        assert "hello" in result["output"]
        assert marker not in result["output"]
|
||||
|
||||
|
||||
class TestEmbedStdinHeredoc:
    """Checks the heredoc text built by ``_embed_stdin_heredoc()``."""

    def test_heredoc_format(self):
        heredoc = BaseEnvironment._embed_stdin_heredoc("cat", "hello world")

        assert heredoc.startswith("cat << '")
        for fragment in ("hello world", "HERMES_STDIN_"):
            assert fragment in heredoc

    def test_unique_delimiter_each_call(self):
        first, second = (
            BaseEnvironment._embed_stdin_heredoc("cat", "data") for _ in range(2)
        )

        # The delimiter is the text between the first pair of single quotes.
        first_delim = first.split("'", 2)[1]
        second_delim = second.split("'", 2)[1]
        assert first_delim != second_delim  # UUID-based, should be unique
|
||||
|
||||
|
||||
class TestInitSessionFailure:
|
||||
def test_snapshot_ready_false_on_failure(self):
|
||||
env = _TestableEnv()
|
||||
|
||||
def failing_run_bash(*args, **kwargs):
|
||||
raise RuntimeError("bash not found")
|
||||
|
||||
env._run_bash = failing_run_bash
|
||||
env.init_session()
|
||||
|
||||
assert env._snapshot_ready is False
|
||||
|
||||
def test_login_flag_when_snapshot_not_ready(self):
|
||||
"""When _snapshot_ready=False, execute() should pass login=True to _run_bash."""
|
||||
env = _TestableEnv()
|
||||
env._snapshot_ready = False
|
||||
|
||||
calls = []
|
||||
def mock_run_bash(cmd, *, login=False, timeout=120, stdin_data=None):
|
||||
calls.append({"login": login})
|
||||
# Return a mock process handle
|
||||
mock = MagicMock()
|
||||
mock.poll.return_value = 0
|
||||
mock.returncode = 0
|
||||
mock.stdout = iter([])
|
||||
return mock
|
||||
|
||||
env._run_bash = mock_run_bash
|
||||
env.execute("echo test")
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0]["login"] is True
|
||||
|
||||
|
||||
class TestCwdMarker:
|
||||
def test_marker_contains_session_id(self):
|
||||
env = _TestableEnv()
|
||||
assert env._session_id in env._cwd_marker
|
||||
|
||||
def test_unique_per_instance(self):
|
||||
env1 = _TestableEnv()
|
||||
env2 = _TestableEnv()
|
||||
assert env1._cwd_marker != env2._cwd_marker
|
||||
|
|
@ -59,8 +59,8 @@ def daytona_sdk(monkeypatch):
|
|||
@pytest.fixture()
|
||||
def make_env(daytona_sdk, monkeypatch):
|
||||
"""Factory that creates a DaytonaEnvironment with a mocked SDK."""
|
||||
# Prevent is_interrupted from interfering
|
||||
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)
|
||||
# Prevent is_interrupted from interfering — patch where it's used (base.py)
|
||||
monkeypatch.setattr("tools.environments.base.is_interrupted", lambda: False)
|
||||
# Prevent skills/credential sync from consuming mock exec calls
|
||||
monkeypatch.setattr("tools.credential_files.get_credential_file_mounts", lambda: [])
|
||||
monkeypatch.setattr("tools.credential_files.get_skills_directory_mount", lambda **kw: None)
|
||||
|
|
@ -221,41 +221,45 @@ class TestCleanup:
|
|||
class TestExecute:
|
||||
def test_basic_command(self, make_env):
|
||||
sb = _make_sandbox()
|
||||
# First call: $HOME detection; subsequent calls: actual commands
|
||||
# Calls: (1) $HOME detection, (2) init_session bootstrap, (3) actual command
|
||||
sb.process.exec.side_effect = [
|
||||
_make_exec_response(result="/root"), # $HOME
|
||||
_make_exec_response(result="", exit_code=0), # init_session
|
||||
_make_exec_response(result="hello", exit_code=0), # actual cmd
|
||||
]
|
||||
sb.state = "started"
|
||||
env = make_env(sandbox=sb)
|
||||
|
||||
result = env.execute("echo hello")
|
||||
assert result["output"] == "hello"
|
||||
assert "hello" in result["output"]
|
||||
assert result["returncode"] == 0
|
||||
|
||||
def test_command_wrapped_with_shell_timeout(self, make_env):
|
||||
def test_sdk_timeout_passed_to_exec(self, make_env):
|
||||
"""SDK native timeout is passed to sandbox.process.exec()."""
|
||||
sb = _make_sandbox()
|
||||
sb.process.exec.side_effect = [
|
||||
_make_exec_response(result="/root"),
|
||||
_make_exec_response(result="", exit_code=0), # init_session
|
||||
_make_exec_response(result="ok", exit_code=0),
|
||||
]
|
||||
sb.state = "started"
|
||||
env = make_env(sandbox=sb, timeout=42)
|
||||
|
||||
env.execute("echo hello")
|
||||
# The command sent to exec should be wrapped with `timeout N sh -c '...'`
|
||||
# The exec call should receive timeout= kwarg (SDK native timeout)
|
||||
call_args = sb.process.exec.call_args_list[-1]
|
||||
assert call_args[1]["timeout"] == 42
|
||||
# The command should NOT have a shell `timeout` prefix
|
||||
cmd = call_args[0][0]
|
||||
assert cmd.startswith("timeout 42 sh -c ")
|
||||
# SDK timeout param should NOT be passed
|
||||
assert "timeout" not in call_args[1]
|
||||
assert not cmd.startswith("timeout ")
|
||||
|
||||
def test_timeout_returns_exit_code_124(self, make_env):
|
||||
"""Shell timeout utility returns exit code 124."""
|
||||
"""SDK-level timeout surfaces as exit code 124 via _wait_for_process."""
|
||||
sb = _make_sandbox()
|
||||
sb.process.exec.side_effect = [
|
||||
_make_exec_response(result="/root"),
|
||||
_make_exec_response(result="", exit_code=124),
|
||||
_make_exec_response(result="", exit_code=0), # init_session
|
||||
_make_exec_response(result="", exit_code=124), # actual cmd
|
||||
]
|
||||
sb.state = "started"
|
||||
env = make_env(sandbox=sb)
|
||||
|
|
@ -267,6 +271,7 @@ class TestExecute:
|
|||
sb = _make_sandbox()
|
||||
sb.process.exec.side_effect = [
|
||||
_make_exec_response(result="/root"),
|
||||
_make_exec_response(result="", exit_code=0), # init_session
|
||||
_make_exec_response(result="not found", exit_code=127),
|
||||
]
|
||||
sb.state = "started"
|
||||
|
|
@ -279,6 +284,7 @@ class TestExecute:
|
|||
sb = _make_sandbox()
|
||||
sb.process.exec.side_effect = [
|
||||
_make_exec_response(result="/root"),
|
||||
_make_exec_response(result="", exit_code=0), # init_session
|
||||
_make_exec_response(result="ok", exit_code=0),
|
||||
]
|
||||
sb.state = "started"
|
||||
|
|
@ -286,39 +292,47 @@ class TestExecute:
|
|||
|
||||
env.execute("python3", stdin_data="print('hi')")
|
||||
# Check that the command passed to exec contains heredoc markers
|
||||
# (single quotes get shell-escaped by shlex.quote, so check components)
|
||||
# Base class uses HERMES_STDIN_ prefix for heredoc delimiters
|
||||
call_args = sb.process.exec.call_args_list[-1]
|
||||
cmd = call_args[0][0]
|
||||
assert "HERMES_EOF_" in cmd
|
||||
assert "HERMES_STDIN_" in cmd
|
||||
assert "print" in cmd
|
||||
assert "hi" in cmd
|
||||
|
||||
def test_custom_cwd_passed_through(self, make_env):
|
||||
def test_custom_cwd_in_command_wrapper(self, make_env):
|
||||
"""CWD is handled by _wrap_command() in the command string, not as a kwarg."""
|
||||
sb = _make_sandbox()
|
||||
sb.process.exec.side_effect = [
|
||||
_make_exec_response(result="/root"),
|
||||
_make_exec_response(result="", exit_code=0), # init_session
|
||||
_make_exec_response(result="/tmp", exit_code=0),
|
||||
]
|
||||
sb.state = "started"
|
||||
env = make_env(sandbox=sb)
|
||||
|
||||
env.execute("pwd", cwd="/tmp")
|
||||
call_kwargs = sb.process.exec.call_args_list[-1][1]
|
||||
assert call_kwargs["cwd"] == "/tmp"
|
||||
# CWD should be embedded in the command string via _wrap_command
|
||||
call_args = sb.process.exec.call_args_list[-1]
|
||||
cmd = call_args[0][0]
|
||||
assert "cd /tmp" in cmd
|
||||
# CWD should NOT be passed as a kwarg to exec
|
||||
assert "cwd" not in call_args[1]
|
||||
|
||||
def test_daytona_error_triggers_retry(self, make_env, daytona_sdk):
|
||||
sb = _make_sandbox()
|
||||
sb.state = "started"
|
||||
sb.process.exec.side_effect = [
|
||||
_make_exec_response(result="/root"), # $HOME
|
||||
_make_exec_response(result="", exit_code=0), # init_session
|
||||
daytona_sdk.DaytonaError("transient"), # first attempt fails
|
||||
_make_exec_response(result="ok", exit_code=0), # retry succeeds
|
||||
]
|
||||
env = make_env(sandbox=sb)
|
||||
|
||||
result = env.execute("echo retry")
|
||||
assert result["output"] == "ok"
|
||||
assert result["returncode"] == 0
|
||||
# DaytonaError now surfaces directly through _ThreadedProcessHandle
|
||||
# (no retry logic) — the error becomes returncode=1
|
||||
assert result["returncode"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -359,14 +373,18 @@ class TestInterrupt:
|
|||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
return _make_exec_response(result="/root") # $HOME detection
|
||||
if calls["n"] == 2:
|
||||
return _make_exec_response(result="", exit_code=0) # init_session
|
||||
event.wait(timeout=5) # simulate long-running command
|
||||
return _make_exec_response(result="done", exit_code=0)
|
||||
|
||||
sb.process.exec.side_effect = exec_side_effect
|
||||
env = make_env(sandbox=sb)
|
||||
|
||||
# is_interrupted is checked by base.py's _wait_for_process,
|
||||
# patch where it's actually referenced (base.py's local binding)
|
||||
monkeypatch.setattr(
|
||||
"tools.environments.daytona.is_interrupted", lambda: True
|
||||
"tools.environments.base.is_interrupted", lambda: True
|
||||
)
|
||||
try:
|
||||
result = env.execute("sleep 10")
|
||||
|
|
@ -377,23 +395,24 @@ class TestInterrupt:
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Retry exhaustion
|
||||
# DaytonaError surfaces directly (no retry)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRetryExhausted:
|
||||
def test_both_attempts_fail(self, make_env, daytona_sdk):
|
||||
"""DaytonaError surfaces directly as rc=1 (retry logic was removed)."""
|
||||
sb = _make_sandbox()
|
||||
sb.state = "started"
|
||||
sb.process.exec.side_effect = [
|
||||
_make_exec_response(result="/root"), # $HOME
|
||||
daytona_sdk.DaytonaError("fail1"), # first attempt
|
||||
daytona_sdk.DaytonaError("fail2"), # retry
|
||||
_make_exec_response(result="", exit_code=0), # init_session
|
||||
daytona_sdk.DaytonaError("fail1"), # actual command fails
|
||||
]
|
||||
env = make_env(sandbox=sb)
|
||||
|
||||
result = env.execute("echo x")
|
||||
# Error surfaces directly through _ThreadedProcessHandle (rc=1)
|
||||
assert result["returncode"] == 1
|
||||
assert "Daytona execution error" in result["output"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -245,43 +245,42 @@ def _make_execute_only_env(forward_env=None):
|
|||
env._timeout_result = lambda timeout: {"output": f"timed out after {timeout}", "returncode": 124}
|
||||
env._container_id = "test-container"
|
||||
env._docker_exe = "/usr/bin/docker"
|
||||
# Base class attributes needed by unified execute()
|
||||
env._session_id = "test123"
|
||||
env._snapshot_path = "/tmp/hermes-snap-test123.sh"
|
||||
env._cwd_file = "/tmp/hermes-cwd-test123.txt"
|
||||
env._cwd_marker = "__HERMES_CWD_test123__"
|
||||
env._snapshot_ready = True
|
||||
env._last_sync_time = None
|
||||
env._init_env_args = []
|
||||
return env
|
||||
|
||||
|
||||
def test_execute_uses_hermes_dotenv_for_allowlisted_env(monkeypatch):
|
||||
def test_init_env_args_uses_hermes_dotenv_for_allowlisted_env(monkeypatch):
|
||||
"""_build_init_env_args picks up forwarded env vars from .env file at init time."""
|
||||
env = _make_execute_only_env(["GITHUB_TOKEN"])
|
||||
popen_calls = []
|
||||
|
||||
def _fake_popen(cmd, **kwargs):
|
||||
popen_calls.append(cmd)
|
||||
return _FakePopen(cmd, **kwargs)
|
||||
|
||||
monkeypatch.delenv("GITHUB_TOKEN", raising=False)
|
||||
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {"GITHUB_TOKEN": "value_from_dotenv"})
|
||||
monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen)
|
||||
|
||||
result = env.execute("echo hi")
|
||||
args = env._build_init_env_args()
|
||||
args_str = " ".join(args)
|
||||
|
||||
assert result["returncode"] == 0
|
||||
assert "GITHUB_TOKEN=value_from_dotenv" in popen_calls[0]
|
||||
assert "GITHUB_TOKEN=value_from_dotenv" in args_str
|
||||
|
||||
|
||||
def test_execute_prefers_shell_env_over_hermes_dotenv(monkeypatch):
|
||||
def test_init_env_args_prefers_shell_env_over_hermes_dotenv(monkeypatch):
|
||||
"""Shell env vars take priority over .env file values in init env args."""
|
||||
env = _make_execute_only_env(["GITHUB_TOKEN"])
|
||||
popen_calls = []
|
||||
|
||||
def _fake_popen(cmd, **kwargs):
|
||||
popen_calls.append(cmd)
|
||||
return _FakePopen(cmd, **kwargs)
|
||||
|
||||
monkeypatch.setenv("GITHUB_TOKEN", "value_from_shell")
|
||||
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {"GITHUB_TOKEN": "value_from_dotenv"})
|
||||
monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen)
|
||||
|
||||
env.execute("echo hi")
|
||||
args = env._build_init_env_args()
|
||||
args_str = " ".join(args)
|
||||
|
||||
assert "GITHUB_TOKEN=value_from_shell" in popen_calls[0]
|
||||
assert "GITHUB_TOKEN=value_from_dotenv" not in popen_calls[0]
|
||||
assert "GITHUB_TOKEN=value_from_shell" in args_str
|
||||
assert "value_from_dotenv" not in args_str
|
||||
|
||||
|
||||
# ── docker_env tests ──────────────────────────────────────────────
|
||||
|
|
@ -302,64 +301,46 @@ def test_docker_env_appears_in_run_command(monkeypatch):
|
|||
assert "GNUPGHOME=/root/.gnupg" in run_args_str
|
||||
|
||||
|
||||
def test_docker_env_appears_in_exec_command(monkeypatch):
|
||||
"""Explicit docker_env values should also be passed via -e at docker exec time."""
|
||||
def test_docker_env_appears_in_init_env_args(monkeypatch):
|
||||
"""Explicit docker_env values should appear in _build_init_env_args."""
|
||||
env = _make_execute_only_env()
|
||||
env._env = {"MY_VAR": "my_value"}
|
||||
popen_calls = []
|
||||
|
||||
def _fake_popen(cmd, **kwargs):
|
||||
popen_calls.append(cmd)
|
||||
return _FakePopen(cmd, **kwargs)
|
||||
args = env._build_init_env_args()
|
||||
args_str = " ".join(args)
|
||||
|
||||
monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen)
|
||||
|
||||
env.execute("echo hi")
|
||||
|
||||
assert popen_calls, "Popen should have been called"
|
||||
assert "MY_VAR=my_value" in popen_calls[0]
|
||||
assert "MY_VAR=my_value" in args_str
|
||||
|
||||
|
||||
def test_forward_env_overrides_docker_env(monkeypatch):
|
||||
def test_forward_env_overrides_docker_env_in_init_args(monkeypatch):
|
||||
"""docker_forward_env should override docker_env for the same key."""
|
||||
env = _make_execute_only_env(forward_env=["MY_KEY"])
|
||||
env._env = {"MY_KEY": "static_value"}
|
||||
popen_calls = []
|
||||
|
||||
def _fake_popen(cmd, **kwargs):
|
||||
popen_calls.append(cmd)
|
||||
return _FakePopen(cmd, **kwargs)
|
||||
|
||||
monkeypatch.setenv("MY_KEY", "dynamic_value")
|
||||
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {})
|
||||
monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen)
|
||||
|
||||
env.execute("echo hi")
|
||||
args = env._build_init_env_args()
|
||||
args_str = " ".join(args)
|
||||
|
||||
cmd_str = " ".join(popen_calls[0])
|
||||
assert "MY_KEY=dynamic_value" in cmd_str
|
||||
assert "MY_KEY=static_value" not in cmd_str
|
||||
assert "MY_KEY=dynamic_value" in args_str
|
||||
assert "MY_KEY=static_value" not in args_str
|
||||
|
||||
|
||||
def test_docker_env_and_forward_env_merge(monkeypatch):
|
||||
def test_docker_env_and_forward_env_merge_in_init_args(monkeypatch):
|
||||
"""docker_env and docker_forward_env with different keys should both appear."""
|
||||
env = _make_execute_only_env(forward_env=["TOKEN"])
|
||||
env._env = {"SSH_AUTH_SOCK": "/run/user/1000/agent.sock"}
|
||||
popen_calls = []
|
||||
|
||||
def _fake_popen(cmd, **kwargs):
|
||||
popen_calls.append(cmd)
|
||||
return _FakePopen(cmd, **kwargs)
|
||||
|
||||
monkeypatch.setenv("TOKEN", "secret123")
|
||||
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {})
|
||||
monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen)
|
||||
|
||||
env.execute("echo hi")
|
||||
args = env._build_init_env_args()
|
||||
args_str = " ".join(args)
|
||||
|
||||
assert "SSH_AUTH_SOCK=/run/user/1000/agent.sock" in args_str
|
||||
assert "TOKEN=secret123" in args_str
|
||||
|
||||
cmd_str = " ".join(popen_calls[0])
|
||||
assert "SSH_AUTH_SOCK=/run/user/1000/agent.sock" in cmd_str
|
||||
assert "TOKEN=secret123" in cmd_str
|
||||
|
||||
|
||||
def test_normalize_env_dict_filters_invalid_keys():
|
||||
|
|
|
|||
|
|
@ -22,21 +22,19 @@ import pytest
|
|||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
||||
|
||||
from tools.environments.local import (
|
||||
LocalEnvironment,
|
||||
_clean_shell_noise,
|
||||
_extract_fenced_output,
|
||||
_OUTPUT_FENCE,
|
||||
_SHELL_NOISE_SUBSTRINGS,
|
||||
)
|
||||
from tools.environments.local import LocalEnvironment
|
||||
from tools.file_operations import ShellFileOperations
|
||||
|
||||
|
||||
# ── Shared noise detection ───────────────────────────────────────────────
|
||||
# Every known shell noise pattern. If ANY of these appear in output that
|
||||
# isn't explicitly expected, the test fails with a clear message.
|
||||
# Known shell noise patterns that should never appear in command output.
|
||||
|
||||
_ALL_NOISE_PATTERNS = list(_SHELL_NOISE_SUBSTRINGS) + [
|
||||
_ALL_NOISE_PATTERNS = [
|
||||
"bash: cannot set terminal process group",
|
||||
"bash: no job control in this shell",
|
||||
"no job control in this shell",
|
||||
"cannot set terminal process group",
|
||||
"tcsetattr: Inappropriate ioctl for device",
|
||||
"bash: ",
|
||||
"Inappropriate ioctl",
|
||||
"Auto-suggestions:",
|
||||
|
|
@ -88,134 +86,6 @@ def populated_dir(tmp_path):
|
|||
return tmp_path
|
||||
|
||||
|
||||
# ── _clean_shell_noise unit tests ────────────────────────────────────────
|
||||
|
||||
class TestCleanShellNoise:
|
||||
def test_single_noise_line(self):
|
||||
output = "bash: no job control in this shell\nhello world\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "hello world\n"
|
||||
|
||||
def test_double_noise_lines(self):
|
||||
output = (
|
||||
"bash: cannot set terminal process group (-1): Inappropriate ioctl for device\n"
|
||||
"bash: no job control in this shell\n"
|
||||
"actual output here\n"
|
||||
)
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "actual output here\n"
|
||||
_assert_clean(result)
|
||||
|
||||
def test_tcsetattr_noise(self):
|
||||
output = (
|
||||
"bash: [12345: 2 (255)] tcsetattr: Inappropriate ioctl for device\n"
|
||||
"real content\n"
|
||||
)
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "real content\n"
|
||||
_assert_clean(result)
|
||||
|
||||
def test_triple_noise_lines(self):
|
||||
output = (
|
||||
"bash: cannot set terminal process group (-1): Inappropriate ioctl for device\n"
|
||||
"bash: no job control in this shell\n"
|
||||
"bash: [999: 2 (255)] tcsetattr: Inappropriate ioctl for device\n"
|
||||
"clean\n"
|
||||
)
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "clean\n"
|
||||
|
||||
def test_no_noise_untouched(self):
|
||||
assert _clean_shell_noise("hello\nworld\n") == "hello\nworld\n"
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _clean_shell_noise("") == ""
|
||||
|
||||
def test_only_noise_produces_empty(self):
|
||||
output = "bash: no job control in this shell\n"
|
||||
result = _clean_shell_noise(output)
|
||||
_assert_clean(result)
|
||||
|
||||
def test_noise_in_middle_not_stripped(self):
|
||||
"""Noise in the middle is real output and should be preserved."""
|
||||
output = "real\nbash: no job control in this shell\nmore real\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == output
|
||||
|
||||
def test_zsh_restored_session(self):
|
||||
output = "Restored session: Mon Mar 2 22:16:54 +03 2026\nhello\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "hello\n"
|
||||
|
||||
def test_zsh_saving_session_trailing(self):
|
||||
output = "hello\nSaving session...completed.\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "hello\n"
|
||||
|
||||
def test_zsh_oh_my_zsh_banner(self):
|
||||
output = "Oh My Zsh on! | Auto-suggestions: press right\nhello\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "hello\n"
|
||||
|
||||
def test_zsh_full_noise_sandwich(self):
|
||||
"""Both leading and trailing zsh noise stripped."""
|
||||
output = (
|
||||
"Restored session: Mon Mar 2\n"
|
||||
"command not found: docker\n"
|
||||
"Oh My Zsh on!\n"
|
||||
"actual output\n"
|
||||
"Saving session...completed.\n"
|
||||
)
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "actual output\n"
|
||||
|
||||
def test_last_login_stripped(self):
|
||||
output = "Last login: Mon Mar 2 22:00:00 on ttys001\nhello\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "hello\n"
|
||||
|
||||
|
||||
# ── _extract_fenced_output unit tests ────────────────────────────────────
|
||||
|
||||
class TestExtractFencedOutput:
|
||||
def test_normal_fenced_output(self):
|
||||
raw = f"noise\n{_OUTPUT_FENCE}hello world\n{_OUTPUT_FENCE}more noise\n"
|
||||
assert _extract_fenced_output(raw) == "hello world\n"
|
||||
|
||||
def test_no_trailing_newline(self):
|
||||
"""printf output with no trailing newline is preserved."""
|
||||
raw = f"noise{_OUTPUT_FENCE}exact{_OUTPUT_FENCE}noise"
|
||||
assert _extract_fenced_output(raw) == "exact"
|
||||
|
||||
def test_no_fences_falls_back(self):
|
||||
"""Without fences, falls back to pattern-based cleaning."""
|
||||
raw = "bash: no job control in this shell\nhello\n"
|
||||
result = _extract_fenced_output(raw)
|
||||
assert result == "hello\n"
|
||||
|
||||
def test_only_start_fence(self):
|
||||
"""Only start fence (e.g. user command called exit)."""
|
||||
raw = f"noise{_OUTPUT_FENCE}hello\nSaving session...\n"
|
||||
result = _extract_fenced_output(raw)
|
||||
assert result == "hello\n"
|
||||
|
||||
def test_user_outputs_fence_string(self):
|
||||
"""If user command outputs the fence marker, it is preserved."""
|
||||
raw = f"noise{_OUTPUT_FENCE}{_OUTPUT_FENCE}real\n{_OUTPUT_FENCE}noise"
|
||||
result = _extract_fenced_output(raw)
|
||||
# first fence -> last fence captures the middle including user's fence
|
||||
assert _OUTPUT_FENCE in result
|
||||
assert "real\n" in result
|
||||
|
||||
def test_empty_command_output(self):
|
||||
raw = f"noise{_OUTPUT_FENCE}{_OUTPUT_FENCE}noise"
|
||||
assert _extract_fenced_output(raw) == ""
|
||||
|
||||
def test_multiline_output(self):
|
||||
raw = f"noise\n{_OUTPUT_FENCE}line1\nline2\nline3\n{_OUTPUT_FENCE}noise\n"
|
||||
assert _extract_fenced_output(raw) == "line1\nline2\nline3\n"
|
||||
|
||||
|
||||
# ── LocalEnvironment.execute() ───────────────────────────────────────────
|
||||
|
||||
class TestLocalEnvironmentExecute:
|
||||
|
|
|
|||
|
|
@ -1,164 +0,0 @@
|
|||
"""Tests for the local persistent shell backend."""
|
||||
|
||||
import glob as glob_mod
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments.local import LocalEnvironment
|
||||
from tools.environments.persistent_shell import PersistentShellMixin
|
||||
|
||||
|
||||
class TestLocalConfig:
|
||||
def test_local_persistent_default_false(self, monkeypatch):
|
||||
monkeypatch.delenv("TERMINAL_LOCAL_PERSISTENT", raising=False)
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["local_persistent"] is False
|
||||
|
||||
def test_local_persistent_true(self, monkeypatch):
|
||||
monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "true")
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["local_persistent"] is True
|
||||
|
||||
def test_local_persistent_yes(self, monkeypatch):
|
||||
monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "yes")
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["local_persistent"] is True
|
||||
|
||||
|
||||
class TestMergeOutput:
|
||||
def test_stdout_only(self):
|
||||
assert PersistentShellMixin._merge_output("out", "") == "out"
|
||||
|
||||
def test_stderr_only(self):
|
||||
assert PersistentShellMixin._merge_output("", "err") == "err"
|
||||
|
||||
def test_both(self):
|
||||
assert PersistentShellMixin._merge_output("out", "err") == "out\nerr"
|
||||
|
||||
def test_empty(self):
|
||||
assert PersistentShellMixin._merge_output("", "") == ""
|
||||
|
||||
def test_strips_trailing_newlines(self):
|
||||
assert PersistentShellMixin._merge_output("out\n\n", "err\n") == "out\nerr"
|
||||
|
||||
|
||||
class TestLocalOneShotRegression:
|
||||
def test_echo(self):
|
||||
env = LocalEnvironment(persistent=False)
|
||||
r = env.execute("echo hello")
|
||||
assert r["returncode"] == 0
|
||||
assert "hello" in r["output"]
|
||||
env.cleanup()
|
||||
|
||||
def test_exit_code(self):
|
||||
env = LocalEnvironment(persistent=False)
|
||||
r = env.execute("exit 42")
|
||||
assert r["returncode"] == 42
|
||||
env.cleanup()
|
||||
|
||||
def test_state_does_not_persist(self):
|
||||
env = LocalEnvironment(persistent=False)
|
||||
env.execute("export HERMES_ONESHOT_LOCAL=yes")
|
||||
r = env.execute("echo $HERMES_ONESHOT_LOCAL")
|
||||
assert r["output"].strip() == ""
|
||||
env.cleanup()
|
||||
|
||||
def test_oneshot_heredoc_does_not_leak_fence_wrapper(self):
|
||||
"""Heredoc closing line must not be merged with the fence wrapper tail."""
|
||||
env = LocalEnvironment(persistent=False)
|
||||
cmd = "cat <<'H_EOF'\nheredoc body line\nH_EOF"
|
||||
r = env.execute(cmd)
|
||||
env.cleanup()
|
||||
assert r["returncode"] == 0
|
||||
assert "heredoc body line" in r["output"]
|
||||
assert "__hermes_rc" not in r["output"]
|
||||
assert "printf '" not in r["output"]
|
||||
assert "exit $" not in r["output"]
|
||||
|
||||
|
||||
class TestLocalPersistent:
|
||||
@pytest.fixture
|
||||
def env(self):
|
||||
e = LocalEnvironment(persistent=True)
|
||||
yield e
|
||||
e.cleanup()
|
||||
|
||||
def test_echo(self, env):
|
||||
r = env.execute("echo hello-persistent")
|
||||
assert r["returncode"] == 0
|
||||
assert "hello-persistent" in r["output"]
|
||||
|
||||
def test_env_var_persists(self, env):
|
||||
env.execute("export HERMES_LOCAL_PERSIST_TEST=works")
|
||||
r = env.execute("echo $HERMES_LOCAL_PERSIST_TEST")
|
||||
assert r["output"].strip() == "works"
|
||||
|
||||
def test_cwd_persists(self, env):
|
||||
env.execute("cd /tmp")
|
||||
r = env.execute("pwd")
|
||||
assert r["output"].strip() == "/tmp"
|
||||
|
||||
def test_exit_code(self, env):
|
||||
r = env.execute("(exit 42)")
|
||||
assert r["returncode"] == 42
|
||||
|
||||
def test_stderr(self, env):
|
||||
r = env.execute("echo oops >&2")
|
||||
assert r["returncode"] == 0
|
||||
assert "oops" in r["output"]
|
||||
|
||||
def test_multiline_output(self, env):
|
||||
r = env.execute("echo a; echo b; echo c")
|
||||
lines = r["output"].strip().splitlines()
|
||||
assert lines == ["a", "b", "c"]
|
||||
|
||||
def test_timeout_then_recovery(self, env):
|
||||
r = env.execute("sleep 999", timeout=2)
|
||||
assert r["returncode"] in (124, 130)
|
||||
r = env.execute("echo alive")
|
||||
assert r["returncode"] == 0
|
||||
assert "alive" in r["output"]
|
||||
|
||||
def test_large_output(self, env):
|
||||
r = env.execute("seq 1 1000")
|
||||
assert r["returncode"] == 0
|
||||
lines = r["output"].strip().splitlines()
|
||||
assert len(lines) == 1000
|
||||
assert lines[0] == "1"
|
||||
assert lines[-1] == "1000"
|
||||
|
||||
def test_shell_variable_persists(self, env):
|
||||
env.execute("MY_LOCAL_VAR=hello123")
|
||||
r = env.execute("echo $MY_LOCAL_VAR")
|
||||
assert r["output"].strip() == "hello123"
|
||||
|
||||
def test_cleanup_removes_temp_files(self, env):
|
||||
env.execute("echo warmup")
|
||||
prefix = env._temp_prefix
|
||||
assert len(glob_mod.glob(f"{prefix}-*")) > 0
|
||||
env.cleanup()
|
||||
remaining = glob_mod.glob(f"{prefix}-*")
|
||||
assert remaining == []
|
||||
|
||||
def test_state_does_not_leak_between_instances(self):
|
||||
env1 = LocalEnvironment(persistent=True)
|
||||
env2 = LocalEnvironment(persistent=True)
|
||||
try:
|
||||
env1.execute("export LEAK_TEST=from_env1")
|
||||
r = env2.execute("echo $LEAK_TEST")
|
||||
assert r["output"].strip() == ""
|
||||
finally:
|
||||
env1.cleanup()
|
||||
env2.cleanup()
|
||||
|
||||
def test_special_characters_in_command(self, env):
|
||||
r = env.execute("echo 'hello world'")
|
||||
assert r["output"].strip() == "hello world"
|
||||
|
||||
def test_pipe_command(self, env):
|
||||
r = env.execute("echo hello | tr 'h' 'H'")
|
||||
assert r["output"].strip() == "Hello"
|
||||
|
||||
def test_multiple_commands_semicolon(self, env):
|
||||
r = env.execute("X=42; echo $X")
|
||||
assert r["output"].strip() == "42"
|
||||
|
|
@ -110,7 +110,7 @@ class _FakeResponse:
|
|||
def test_managed_modal_execute_polls_until_completed(monkeypatch):
|
||||
_install_fake_tools_package()
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
modal_common = sys.modules["tools.environments.modal_common"]
|
||||
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||
|
||||
calls = []
|
||||
poll_count = {"value": 0}
|
||||
|
|
@ -173,7 +173,7 @@ def test_managed_modal_create_sends_a_stable_idempotency_key(monkeypatch):
|
|||
def test_managed_modal_execute_cancels_on_interrupt(monkeypatch):
|
||||
interrupt_event = _install_fake_tools_package()
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
modal_common = sys.modules["tools.environments.modal_common"]
|
||||
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||
|
||||
calls = []
|
||||
|
||||
|
|
@ -215,7 +215,7 @@ def test_managed_modal_execute_cancels_on_interrupt(monkeypatch):
|
|||
def test_managed_modal_execute_returns_descriptive_error_on_missing_exec(monkeypatch):
|
||||
_install_fake_tools_package()
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
modal_common = sys.modules["tools.environments.modal_common"]
|
||||
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||
|
||||
def fake_request(method, url, headers=None, json=None, timeout=None):
|
||||
if method == "POST" and url.endswith("/v1/sandboxes"):
|
||||
|
|
@ -293,7 +293,7 @@ def test_managed_modal_rejects_host_credential_passthrough():
|
|||
def test_managed_modal_execute_times_out_and_cancels(monkeypatch):
|
||||
_install_fake_tools_package()
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
modal_common = sys.modules["tools.environments.modal_common"]
|
||||
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||
|
||||
calls = []
|
||||
monotonic_values = iter([0.0, 12.5])
|
||||
|
|
|
|||
|
|
@ -231,20 +231,20 @@ class TestEnsurepipFix:
|
|||
"""Verify the pip fix is applied in the ModalEnvironment init."""
|
||||
|
||||
def test_modal_environment_creates_image_with_setup_commands(self):
|
||||
"""ModalEnvironment.__init__ should create a modal.Image with pip fix."""
|
||||
"""_resolve_modal_image should create a modal.Image with pip fix."""
|
||||
try:
|
||||
from tools.environments.modal import ModalEnvironment
|
||||
from tools.environments.modal import _resolve_modal_image
|
||||
except ImportError:
|
||||
pytest.skip("tools.environments.modal not importable")
|
||||
|
||||
import inspect
|
||||
source = inspect.getsource(ModalEnvironment.__init__)
|
||||
source = inspect.getsource(_resolve_modal_image)
|
||||
assert "ensurepip" in source, (
|
||||
"ModalEnvironment should include ensurepip fix "
|
||||
"_resolve_modal_image should include ensurepip fix "
|
||||
"for Modal's legacy image builder"
|
||||
)
|
||||
assert "setup_dockerfile_commands" in source, (
|
||||
"ModalEnvironment should use setup_dockerfile_commands "
|
||||
"_resolve_modal_image should use setup_dockerfile_commands "
|
||||
"to fix pip before Modal's bootstrap"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -85,11 +85,47 @@ def _install_modal_test_modules(
|
|||
def _prepare_command(self, command: str):
|
||||
return command, None
|
||||
|
||||
sys.modules["tools.environments.base"] = types.SimpleNamespace(BaseEnvironment=_DummyBaseEnvironment)
|
||||
def init_session(self):
|
||||
pass
|
||||
|
||||
# Stub _ThreadedProcessHandle: modal.py imports it but only uses it at
|
||||
# runtime inside _run_bash; the snapshot-isolation tests never call _run_bash,
|
||||
# so a class placeholder is sufficient.
|
||||
class _DummyThreadedProcessHandle:
|
||||
def __init__(self, exec_fn, cancel_fn=None):
|
||||
pass
|
||||
|
||||
def _load_json_store(path):
|
||||
if path.exists():
|
||||
try:
|
||||
return json.loads(path.read_text())
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
def _save_json_store(path, data):
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(data, indent=2))
|
||||
|
||||
def _file_mtime_key(host_path):
|
||||
try:
|
||||
st = Path(host_path).stat()
|
||||
return (st.st_mtime, st.st_size)
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
sys.modules["tools.environments.base"] = types.SimpleNamespace(
|
||||
BaseEnvironment=_DummyBaseEnvironment,
|
||||
_ThreadedProcessHandle=_DummyThreadedProcessHandle,
|
||||
_load_json_store=_load_json_store,
|
||||
_save_json_store=_save_json_store,
|
||||
_file_mtime_key=_file_mtime_key,
|
||||
)
|
||||
sys.modules["tools.interrupt"] = types.SimpleNamespace(is_interrupted=lambda: False)
|
||||
sys.modules["tools.credential_files"] = types.SimpleNamespace(
|
||||
get_credential_file_mounts=lambda: [],
|
||||
iter_skills_files=lambda: [],
|
||||
iter_cache_files=lambda: [],
|
||||
)
|
||||
|
||||
from_id_calls: list[str] = []
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ class TestBuildSSHCommand:
|
|||
lambda *a, **k: MagicMock(stdout=iter([]),
|
||||
stderr=iter([]),
|
||||
stdin=MagicMock()))
|
||||
monkeypatch.setattr("tools.environments.ssh.time.sleep", lambda _: None)
|
||||
monkeypatch.setattr("tools.environments.base.time.sleep", lambda _: None)
|
||||
|
||||
def test_base_flags(self):
|
||||
env = SSHEnvironment(host="h", user="u")
|
||||
|
|
|
|||
21
tests/tools/test_terminal_none_command_guard.py
Normal file
21
tests/tools/test_terminal_none_command_guard.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
"""Regression tests for invalid/None terminal command handling."""
|
||||
|
||||
import json
|
||||
|
||||
from tools.terminal_tool import _transform_sudo_command, terminal_tool
|
||||
|
||||
|
||||
def test_transform_sudo_command_none_returns_cleanly():
    """A ``None`` command goes through the sudo rewriter without raising."""
    outcome = _transform_sudo_command(None)
    rewritten, stdin_payload = outcome

    assert rewritten is None
    assert stdin_payload is None
|
||||
|
||||
|
||||
def test_terminal_tool_none_command_returns_clean_error():
    """``terminal_tool(None)`` answers with a structured JSON error payload."""
    raw_reply = terminal_tool(None)  # type: ignore[arg-type]
    payload = json.loads(raw_reply)

    assert payload["exit_code"] == -1
    assert payload["status"] == "error"

    # The message must name both the expected and the received type.
    message = payload["error"].lower()
    assert "expected string" in message
    assert "nonetype" in message
|
||||
90
tests/tools/test_terminal_tool.py
Normal file
90
tests/tools/test_terminal_tool.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
"""Regression tests for sudo detection and sudo password handling."""
|
||||
|
||||
import tools.terminal_tool as terminal_tool
|
||||
|
||||
|
||||
def setup_function():
    """Clear the module-level cached sudo password before each test."""
    setattr(terminal_tool, "_cached_sudo_password", "")
|
||||
|
||||
|
||||
def teardown_function():
    """Clear the module-level cached sudo password after each test."""
    setattr(terminal_tool, "_cached_sudo_password", "")
|
||||
|
||||
|
||||
def test_searching_for_sudo_does_not_trigger_rewrite(monkeypatch):
    """A quoted 'sudo' search pattern is an argument, not a sudo invocation."""
    for variable in ("SUDO_PASSWORD", "HERMES_INTERACTIVE"):
        monkeypatch.delenv(variable, raising=False)

    command = "rg --line-number --no-heading --with-filename 'sudo' . | head -n 20"
    outcome = terminal_tool._transform_sudo_command(command)
    transformed, sudo_stdin = outcome

    assert transformed == command
    assert sudo_stdin is None
|
||||
|
||||
|
||||
def test_printf_literal_sudo_does_not_trigger_rewrite(monkeypatch):
    """Printing the word 'sudo' must leave the command unchanged."""
    for variable in ("SUDO_PASSWORD", "HERMES_INTERACTIVE"):
        monkeypatch.delenv(variable, raising=False)

    command = "printf '%s\\n' sudo"
    outcome = terminal_tool._transform_sudo_command(command)
    transformed, sudo_stdin = outcome

    assert transformed == command
    assert sudo_stdin is None
|
||||
|
||||
|
||||
def test_non_command_argument_named_sudo_does_not_trigger_rewrite(monkeypatch):
    """'sudo' used as a plain grep pattern must not be rewritten."""
    for variable in ("SUDO_PASSWORD", "HERMES_INTERACTIVE"):
        monkeypatch.delenv(variable, raising=False)

    command = "grep -n sudo README.md"
    outcome = terminal_tool._transform_sudo_command(command)
    transformed, sudo_stdin = outcome

    assert transformed == command
    assert sudo_stdin is None
|
||||
|
||||
|
||||
def test_actual_sudo_command_uses_configured_password(monkeypatch):
    """A real sudo call gets ``-S -p ''`` and the env password on stdin."""
    monkeypatch.setenv("SUDO_PASSWORD", "testpass")
    monkeypatch.delenv("HERMES_INTERACTIVE", raising=False)

    original = "sudo apt install -y ripgrep"
    rewritten, stdin_payload = terminal_tool._transform_sudo_command(original)

    assert rewritten == "sudo -S -p '' apt install -y ripgrep"
    assert stdin_payload == "testpass\n"
|
||||
|
||||
|
||||
def test_actual_sudo_after_leading_env_assignment_is_rewritten(monkeypatch):
    """sudo following a ``VAR=value`` prefix is still recognised and rewritten."""
    monkeypatch.setenv("SUDO_PASSWORD", "testpass")
    monkeypatch.delenv("HERMES_INTERACTIVE", raising=False)

    original = "DEBUG=1 sudo whoami"
    rewritten, stdin_payload = terminal_tool._transform_sudo_command(original)

    assert rewritten == "DEBUG=1 sudo -S -p '' whoami"
    assert stdin_payload == "testpass\n"
|
||||
|
||||
|
||||
def test_explicit_empty_sudo_password_tries_empty_without_prompt(monkeypatch):
    """An explicitly empty SUDO_PASSWORD is used as-is; no interactive prompt."""
    monkeypatch.setenv("SUDO_PASSWORD", "")
    monkeypatch.setenv("HERMES_INTERACTIVE", "1")

    def _prompt_must_not_run(*_args, **_kwargs):
        # Reaching the prompt means the empty password was treated as "unset".
        raise AssertionError("interactive sudo prompt should not run for explicit empty password")

    monkeypatch.setattr(terminal_tool, "_prompt_for_sudo_password", _prompt_must_not_run)

    rewritten, stdin_payload = terminal_tool._transform_sudo_command("sudo true")

    assert rewritten == "sudo -S -p '' true"
    assert stdin_payload == "\n"
|
||||
|
||||
|
||||
def test_cached_sudo_password_is_used_when_env_is_unset(monkeypatch):
    """With no SUDO_PASSWORD in the environment, the cached password is used."""
    for variable in ("SUDO_PASSWORD", "HERMES_INTERACTIVE"):
        monkeypatch.delenv(variable, raising=False)
    setattr(terminal_tool, "_cached_sudo_password", "cached-pass")

    original = "echo ok && sudo whoami"
    rewritten, stdin_payload = terminal_tool._transform_sudo_command(original)

    assert rewritten == "echo ok && sudo -S -p '' whoami"
    assert stdin_payload == "cached-pass\n"
|
||||
144
tests/tools/test_threaded_process_handle.py
Normal file
144
tests/tools/test_threaded_process_handle.py
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
"""Tests for _ThreadedProcessHandle — the adapter for SDK backends."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
from tools.environments.base import _ThreadedProcessHandle
|
||||
|
||||
|
||||
class TestBasicExecution:
    """Exit codes and captured output for an ``exec_fn`` that runs to completion."""

    @staticmethod
    def _finished(exec_fn):
        """Start a handle for *exec_fn* and wait (up to 5 s) for it to finish."""
        handle = _ThreadedProcessHandle(exec_fn)
        handle.wait(timeout=5)
        return handle

    def test_successful_execution(self):
        handle = self._finished(lambda: ("hello world", 0))

        assert handle.returncode == 0
        assert "hello world" in handle.stdout.read()

    def test_nonzero_exit_code(self):
        handle = self._finished(lambda: ("error occurred", 42))

        assert handle.returncode == 42
        assert "error occurred" in handle.stdout.read()

    def test_exception_in_exec_fn(self):
        def exploding():
            raise RuntimeError("boom")

        handle = self._finished(exploding)

        # A raising exec_fn is reported as exit code 1.
        assert handle.returncode == 1

    def test_empty_output(self):
        handle = self._finished(lambda: ("", 0))

        assert handle.returncode == 0
        assert handle.stdout.read() == ""
|
||||
|
||||
|
||||
class TestPolling:
    """``poll()`` is ``None`` while running and the exit code once finished."""

    def test_poll_returns_none_while_running(self):
        release = threading.Event()

        def blocked_until_released():
            release.wait(timeout=5)
            return ("done", 0)

        handle = _ThreadedProcessHandle(blocked_until_released)
        # The worker is parked on the event, so the handle is not finished yet.
        assert handle.poll() is None

        release.set()
        handle.wait(timeout=5)
        assert handle.poll() == 0

    def test_poll_returns_returncode_when_done(self):
        handle = _ThreadedProcessHandle(lambda: ("ok", 0))
        handle.wait(timeout=5)

        assert handle.poll() == 0
|
||||
|
||||
|
||||
class TestCancelFn:
    """``kill()`` invokes the optional ``cancel_fn`` and never raises."""

    def test_cancel_fn_called_on_kill(self):
        called = threading.Event()

        def cancel():
            called.set()

        def exec_fn():
            # Block until cancelled instead of sleeping a fixed 10 s, so no
            # worker thread is left running after the test has finished.
            called.wait(timeout=5)
            return ("", 0)

        handle = _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
        handle.kill()
        assert called.is_set()

        # cancel() released the worker; make sure it really winds down.
        handle.wait(timeout=5)
        assert handle.poll() == 0

    def test_cancel_fn_none_is_safe(self):
        def exec_fn():
            return ("ok", 0)

        handle = _ThreadedProcessHandle(exec_fn, cancel_fn=None)
        handle.kill()  # should not raise
        handle.wait(timeout=5)
        assert handle.returncode == 0

    def test_cancel_fn_exception_swallowed(self):
        def cancel():
            raise RuntimeError("cancel failed")

        def exec_fn():
            return ("ok", 0)

        handle = _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
        handle.kill()  # should not raise despite cancel raising
        handle.wait(timeout=5)
        # A failing cancel_fn must not disturb the result of exec_fn.
        assert handle.returncode == 0
|
||||
|
||||
|
||||
class TestStdoutPipe:
    """The handle's ``stdout`` behaves like a readable text stream."""

    @staticmethod
    def _finished_with(text):
        """Run an exec_fn that returns *text* with exit code 0 and wait for it."""
        handle = _ThreadedProcessHandle(lambda: (text, 0))
        handle.wait(timeout=5)
        return handle

    def test_stdout_is_readable(self):
        handle = self._finished_with("line1\nline2\nline3\n")

        lines = handle.stdout.readlines()
        assert len(lines) == 3
        assert lines[0] == "line1\n"

    def test_stdout_iterable(self):
        handle = self._finished_with("a\nb\nc\n")

        collected = [line for line in handle.stdout]
        assert len(collected) == 3

    def test_unicode_output(self):
        handle = self._finished_with("hello 世界 🌍\n")

        output = handle.stdout.read()
        assert "世界" in output
        assert "🌍" in output
|
||||
|
|
@ -18,7 +18,7 @@ Architecture (two transports):
|
|||
2. Parent ships both files to the remote environment
|
||||
3. Script runs inside the terminal backend (Docker/SSH/Modal/Daytona/etc.)
|
||||
4. Tool calls are written as request files; a polling thread on the parent
|
||||
reads them via execute_oneshot(), dispatches, and writes response files
|
||||
reads them via env.execute(), dispatches, and writes response files
|
||||
5. The script polls for response files and continues
|
||||
|
||||
In both cases, only the script's stdout is returned to the LLM; intermediate
|
||||
|
|
@ -536,7 +536,7 @@ def _ship_file_to_remote(env, remote_path: str, content: str) -> None:
|
|||
quotes are fine.
|
||||
"""
|
||||
encoded = base64.b64encode(content.encode("utf-8")).decode("ascii")
|
||||
env.execute_oneshot(
|
||||
env.execute(
|
||||
f"echo '{encoded}' | base64 -d > {remote_path}",
|
||||
cwd="/",
|
||||
timeout=30,
|
||||
|
|
@ -555,9 +555,9 @@ def _rpc_poll_loop(
|
|||
):
|
||||
"""Poll the remote filesystem for tool call requests and dispatch them.
|
||||
|
||||
Runs in a background thread. Uses ``env.execute_oneshot()`` so it can
|
||||
operate concurrently with the script-execution thread that holds
|
||||
``env.execute()`` (important for persistent-shell backends like SSH).
|
||||
Runs in a background thread. Each ``env.execute()`` spawns an
|
||||
independent process, so these calls run safely concurrent with the
|
||||
script-execution thread.
|
||||
"""
|
||||
from model_tools import handle_function_call
|
||||
|
||||
|
|
@ -566,7 +566,7 @@ def _rpc_poll_loop(
|
|||
while not stop_event.is_set():
|
||||
try:
|
||||
# List pending request files (skip .tmp partials)
|
||||
ls_result = env.execute_oneshot(
|
||||
ls_result = env.execute(
|
||||
f"ls -1 {rpc_dir}/req_* 2>/dev/null || true",
|
||||
cwd="/",
|
||||
timeout=10,
|
||||
|
|
@ -590,7 +590,7 @@ def _rpc_poll_loop(
|
|||
call_start = time.monotonic()
|
||||
|
||||
# Read request
|
||||
read_result = env.execute_oneshot(
|
||||
read_result = env.execute(
|
||||
f"cat {req_file}",
|
||||
cwd="/",
|
||||
timeout=10,
|
||||
|
|
@ -600,7 +600,7 @@ def _rpc_poll_loop(
|
|||
except (json.JSONDecodeError, ValueError):
|
||||
logger.debug("Malformed RPC request in %s", req_file)
|
||||
# Remove bad request to avoid infinite retry
|
||||
env.execute_oneshot(f"rm -f {req_file}", cwd="/", timeout=5)
|
||||
env.execute(f"rm -f {req_file}", cwd="/", timeout=5)
|
||||
continue
|
||||
|
||||
tool_name = request.get("tool", "")
|
||||
|
|
@ -664,7 +664,7 @@ def _rpc_poll_loop(
|
|||
encoded_result = base64.b64encode(
|
||||
tool_result.encode("utf-8")
|
||||
).decode("ascii")
|
||||
env.execute_oneshot(
|
||||
env.execute(
|
||||
f"echo '{encoded_result}' | base64 -d > {res_file}.tmp"
|
||||
f" && mv {res_file}.tmp {res_file}",
|
||||
cwd="/",
|
||||
|
|
@ -672,7 +672,7 @@ def _rpc_poll_loop(
|
|||
)
|
||||
|
||||
# Remove the request file
|
||||
env.execute_oneshot(f"rm -f {req_file}", cwd="/", timeout=5)
|
||||
env.execute(f"rm -f {req_file}", cwd="/", timeout=5)
|
||||
|
||||
except Exception as e:
|
||||
if not stop_event.is_set():
|
||||
|
|
@ -717,7 +717,7 @@ def _execute_remote(
|
|||
|
||||
try:
|
||||
# Verify Python is available on the remote
|
||||
py_check = env.execute_oneshot(
|
||||
py_check = env.execute(
|
||||
"command -v python3 >/dev/null 2>&1 && echo OK",
|
||||
cwd="/", timeout=15,
|
||||
)
|
||||
|
|
@ -734,7 +734,7 @@ def _execute_remote(
|
|||
})
|
||||
|
||||
# Create sandbox directory on remote
|
||||
env.execute_oneshot(
|
||||
env.execute(
|
||||
f"mkdir -p {sandbox_dir}/rpc", cwd="/", timeout=10,
|
||||
)
|
||||
|
||||
|
|
@ -806,7 +806,7 @@ def _execute_remote(
|
|||
|
||||
# Clean up remote sandbox dir
|
||||
try:
|
||||
env.execute_oneshot(
|
||||
env.execute(
|
||||
f"rm -rf {sandbox_dir}", cwd="/", timeout=15,
|
||||
)
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -455,7 +455,7 @@ Important safety rule: cron-run sessions should not recursively schedule more cr
|
|||
},
|
||||
"deliver": {
|
||||
"type": "string",
|
||||
"description": "Delivery target: origin, local, telegram, discord, slack, whatsapp, signal, matrix, mattermost, homeassistant, dingtalk, feishu, wecom, email, sms, or platform:chat_id or platform:chat_id:thread_id for Telegram topics. Examples: 'origin', 'local', 'telegram', 'telegram:-1001234567890:17585', 'discord:#engineering'"
|
||||
"description": "Delivery target: origin, local, telegram, discord, slack, whatsapp, signal, matrix, mattermost, homeassistant, dingtalk, feishu, wecom, email, sms, bluebubbles, or platform:chat_id or platform:chat_id:thread_id for Telegram topics. Examples: 'origin', 'local', 'telegram', 'telegram:-1001234567890:17585', 'discord:#engineering'"
|
||||
},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
|
|
|
|||
|
|
@ -1,11 +1,27 @@
|
|||
"""Base class for all Hermes execution environment backends."""
|
||||
"""Base class for all Hermes execution environment backends.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
Unified spawn-per-call model: every command spawns a fresh ``bash -c`` process.
|
||||
A session snapshot (env vars, functions, aliases) is captured once at init and
|
||||
re-sourced before each command. CWD persists via in-band stdout markers (remote)
|
||||
or a temp file (local).
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import IO, Callable, Protocol
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_sandbox_dir() -> Path:
|
||||
|
|
@ -23,30 +39,501 @@ def get_sandbox_dir() -> Path:
|
|||
return p
|
||||
|
||||
|
||||
class BaseEnvironment(ABC):
|
||||
"""Common interface for all Hermes execution backends.
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared constants and utilities
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Subclasses implement execute() and cleanup(). Shared helpers eliminate
|
||||
duplicated subprocess boilerplate across backends.
|
||||
_SYNC_INTERVAL_SECONDS = 5.0
|
||||
|
||||
|
||||
def _pipe_stdin(proc: subprocess.Popen, data: str) -> None:
|
||||
"""Write *data* to proc.stdin on a daemon thread to avoid pipe-buffer deadlocks."""
|
||||
|
||||
def _write():
|
||||
try:
|
||||
proc.stdin.write(data)
|
||||
proc.stdin.close()
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
|
||||
threading.Thread(target=_write, daemon=True).start()
|
||||
|
||||
|
||||
def _popen_bash(
|
||||
cmd: list[str], stdin_data: str | None = None, **kwargs
|
||||
) -> subprocess.Popen:
|
||||
"""Spawn a subprocess with standard stdout/stderr/stdin setup.
|
||||
|
||||
If *stdin_data* is provided, writes it asynchronously via :func:`_pipe_stdin`.
|
||||
Backends with special Popen needs (e.g. local's ``preexec_fn``) can bypass
|
||||
this and call :func:`_pipe_stdin` directly.
|
||||
"""
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if stdin_data is not None else subprocess.DEVNULL,
|
||||
text=True,
|
||||
**kwargs,
|
||||
)
|
||||
if stdin_data is not None:
|
||||
_pipe_stdin(proc, stdin_data)
|
||||
return proc
|
||||
|
||||
|
||||
def _load_json_store(path: Path) -> dict:
|
||||
"""Load a JSON file as a dict, returning ``{}`` on any error."""
|
||||
if path.exists():
|
||||
try:
|
||||
return json.loads(path.read_text())
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
def _save_json_store(path: Path, data: dict) -> None:
|
||||
"""Write *data* as pretty-printed JSON to *path*."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(data, indent=2))
|
||||
|
||||
|
||||
def _file_mtime_key(host_path: str) -> tuple[float, int] | None:
|
||||
"""Return ``(mtime, size)`` for cache comparison, or ``None`` if unreadable."""
|
||||
try:
|
||||
st = Path(host_path).stat()
|
||||
return (st.st_mtime, st.st_size)
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ProcessHandle protocol
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ProcessHandle(Protocol):
|
||||
"""Duck type that every backend's _run_bash() must return.
|
||||
|
||||
subprocess.Popen satisfies this natively. SDK backends (Modal, Daytona)
|
||||
return _ThreadedProcessHandle which adapts their blocking calls.
|
||||
"""
|
||||
|
||||
def poll(self) -> int | None: ...
|
||||
def kill(self) -> None: ...
|
||||
def wait(self, timeout: float | None = None) -> int: ...
|
||||
|
||||
@property
|
||||
def stdout(self) -> IO[str] | None: ...
|
||||
|
||||
@property
|
||||
def returncode(self) -> int | None: ...
|
||||
|
||||
|
||||
class _ThreadedProcessHandle:
|
||||
"""Adapter for SDK backends (Modal, Daytona) that have no real subprocess.
|
||||
|
||||
Wraps a blocking ``exec_fn() -> (output_str, exit_code)`` in a background
|
||||
thread and exposes a ProcessHandle-compatible interface. An optional
|
||||
``cancel_fn`` is invoked on ``kill()`` for backend-specific cancellation
|
||||
(e.g. Modal sandbox.terminate, Daytona sandbox.stop).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exec_fn: Callable[[], tuple[str, int]],
|
||||
cancel_fn: Callable[[], None] | None = None,
|
||||
):
|
||||
self._cancel_fn = cancel_fn
|
||||
self._done = threading.Event()
|
||||
self._returncode: int | None = None
|
||||
self._error: Exception | None = None
|
||||
|
||||
# Pipe for stdout — drain thread in _wait_for_process reads the read end.
|
||||
read_fd, write_fd = os.pipe()
|
||||
self._stdout = os.fdopen(read_fd, "r", encoding="utf-8", errors="replace")
|
||||
self._write_fd = write_fd
|
||||
|
||||
def _worker():
|
||||
try:
|
||||
output, exit_code = exec_fn()
|
||||
self._returncode = exit_code
|
||||
# Write output into the pipe so drain thread picks it up.
|
||||
try:
|
||||
os.write(self._write_fd, output.encode("utf-8", errors="replace"))
|
||||
except OSError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
self._error = exc
|
||||
self._returncode = 1
|
||||
finally:
|
||||
try:
|
||||
os.close(self._write_fd)
|
||||
except OSError:
|
||||
pass
|
||||
self._done.set()
|
||||
|
||||
t = threading.Thread(target=_worker, daemon=True)
|
||||
t.start()
|
||||
|
||||
@property
|
||||
def stdout(self):
|
||||
return self._stdout
|
||||
|
||||
@property
|
||||
def returncode(self) -> int | None:
|
||||
return self._returncode
|
||||
|
||||
def poll(self) -> int | None:
|
||||
return self._returncode if self._done.is_set() else None
|
||||
|
||||
def kill(self):
|
||||
if self._cancel_fn:
|
||||
try:
|
||||
self._cancel_fn()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def wait(self, timeout: float | None = None) -> int:
|
||||
self._done.wait(timeout=timeout)
|
||||
return self._returncode
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CWD marker for remote backends
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _cwd_marker(session_id: str) -> str:
|
||||
return f"__HERMES_CWD_{session_id}__"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BaseEnvironment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BaseEnvironment(ABC):
|
||||
"""Common interface and unified execution flow for all Hermes backends.
|
||||
|
||||
Subclasses implement ``_run_bash()`` and ``cleanup()``. The base class
|
||||
provides ``execute()`` with session snapshot sourcing, CWD tracking,
|
||||
interrupt handling, and timeout enforcement.
|
||||
"""
|
||||
|
||||
# Subclasses that embed stdin as a heredoc (Modal, Daytona) set this.
|
||||
_stdin_mode: str = "pipe" # "pipe" or "heredoc"
|
||||
|
||||
# Snapshot creation timeout (override for slow cold-starts).
|
||||
_snapshot_timeout: int = 30
|
||||
|
||||
def __init__(self, cwd: str, timeout: int, env: dict = None):
|
||||
self.cwd = cwd
|
||||
self.timeout = timeout
|
||||
self.env = env or {}
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
"""Execute a command, return {"output": str, "returncode": int}."""
|
||||
...
|
||||
self._session_id = uuid.uuid4().hex[:12]
|
||||
self._snapshot_path = f"/tmp/hermes-snap-{self._session_id}.sh"
|
||||
self._cwd_file = f"/tmp/hermes-cwd-{self._session_id}.txt"
|
||||
self._cwd_marker = _cwd_marker(self._session_id)
|
||||
self._snapshot_ready = False
|
||||
self._last_sync_time: float | None = (
|
||||
None # set to 0 by backends that need file sync
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Abstract methods
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run_bash(
|
||||
self,
|
||||
cmd_string: str,
|
||||
*,
|
||||
login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None,
|
||||
) -> ProcessHandle:
|
||||
"""Spawn a bash process to run *cmd_string*.
|
||||
|
||||
Returns a ProcessHandle (subprocess.Popen or _ThreadedProcessHandle).
|
||||
Must be overridden by every backend.
|
||||
"""
|
||||
raise NotImplementedError(f"{type(self).__name__} must implement _run_bash()")
|
||||
|
||||
@abstractmethod
def cleanup(self):
    """Release backend resources (container, instance, connection).

    Abstract: each concrete backend provides its own implementation.
    """
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Session snapshot (init_session)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def init_session(self):
    """Capture the login-shell environment into a snapshot file.

    Called once after backend construction. Runs a bootstrap script in a
    login shell that dumps exported variables, functions, aliases and a few
    shell options to ``self._snapshot_path``, and reports the current
    directory.

    On success ``_snapshot_ready`` is set to ``True`` so subsequent commands
    source the snapshot instead of running with ``bash -l``. If the
    bootstrap raises or does not exit with status 0 (for example because it
    timed out or was interrupted), ``_snapshot_ready`` stays ``False`` and a
    warning is logged. This method never raises.
    """
    # Full capture: env vars, functions (filtered), aliases, shell options.
    bootstrap = (
        f"export -p > {self._snapshot_path}\n"
        f"declare -f | grep -vE '^_[^_]' >> {self._snapshot_path}\n"
        f"alias -p >> {self._snapshot_path}\n"
        f"echo 'shopt -s expand_aliases' >> {self._snapshot_path}\n"
        f"echo 'set +e' >> {self._snapshot_path}\n"
        f"echo 'set +u' >> {self._snapshot_path}\n"
        f"pwd -P > {self._cwd_file} 2>/dev/null || true\n"
        f"printf '\\n{self._cwd_marker}%s{self._cwd_marker}\\n' \"$(pwd -P)\"\n"
    )
    try:
        proc = self._run_bash(bootstrap, login=True, timeout=self._snapshot_timeout)
        result = self._wait_for_process(proc, timeout=self._snapshot_timeout)

        # _wait_for_process reports timeouts and interrupts through the
        # return code rather than by raising, so check it explicitly: a
        # bootstrap that did not finish must not be treated as a snapshot.
        status = result.get("returncode")
        if status != 0:
            raise RuntimeError(f"snapshot bootstrap exited with status {status}")

        self._snapshot_ready = True
        self._update_cwd(result)
        logger.info(
            "Session snapshot created (session=%s, cwd=%s)",
            self._session_id,
            self.cwd,
        )
    except Exception as exc:
        logger.warning(
            "init_session failed (session=%s): %s — "
            "falling back to bash -l per command",
            self._session_id,
            exc,
        )
        self._snapshot_ready = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Command wrapping
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _wrap_command(self, command: str, cwd: str) -> str:
|
||||
"""Build the full bash script that sources snapshot, cd's, runs command,
|
||||
re-dumps env vars, and emits CWD markers."""
|
||||
escaped = command.replace("'", "'\\''")
|
||||
|
||||
parts = []
|
||||
|
||||
# Source snapshot (env vars from previous commands)
|
||||
if self._snapshot_ready:
|
||||
parts.append(f"source {self._snapshot_path} 2>/dev/null || true")
|
||||
|
||||
# cd to working directory — let bash expand ~ natively
|
||||
quoted_cwd = (
|
||||
shlex.quote(cwd) if cwd != "~" and not cwd.startswith("~/") else cwd
|
||||
)
|
||||
parts.append(f"cd {quoted_cwd} || exit 126")
|
||||
|
||||
# Run the actual command
|
||||
parts.append(f"eval '{escaped}'")
|
||||
parts.append("__hermes_ec=$?")
|
||||
|
||||
# Re-dump env vars to snapshot (last-writer-wins for concurrent calls)
|
||||
if self._snapshot_ready:
|
||||
parts.append(f"export -p > {self._snapshot_path} 2>/dev/null || true")
|
||||
|
||||
# Write CWD to file (local reads this) and stdout marker (remote parses this)
|
||||
parts.append(f"pwd -P > {self._cwd_file} 2>/dev/null || true")
|
||||
# Use a distinct line for the marker. The leading \n ensures
|
||||
# the marker starts on its own line even if the command doesn't
|
||||
# end with a newline (e.g. printf 'exact'). We'll strip this
|
||||
# injected newline in _extract_cwd_from_output.
|
||||
parts.append(
|
||||
f"printf '\\n{self._cwd_marker}%s{self._cwd_marker}\\n' \"$(pwd -P)\""
|
||||
)
|
||||
parts.append("exit $__hermes_ec")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Stdin heredoc embedding (for SDK backends)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _embed_stdin_heredoc(command: str, stdin_data: str) -> str:
|
||||
"""Append stdin_data as a shell heredoc to the command string."""
|
||||
delimiter = f"HERMES_STDIN_{uuid.uuid4().hex[:12]}"
|
||||
return f"{command} << '{delimiter}'\n{stdin_data}\n{delimiter}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Process lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _wait_for_process(self, proc: ProcessHandle, timeout: int = 120) -> dict:
|
||||
"""Poll-based wait with interrupt checking and stdout draining.
|
||||
|
||||
Shared across all backends — not overridden.
|
||||
"""
|
||||
output_chunks: list[str] = []
|
||||
|
||||
def _drain():
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
output_chunks.append(line)
|
||||
except UnicodeDecodeError:
|
||||
output_chunks.clear()
|
||||
output_chunks.append(
|
||||
"[binary output detected — raw bytes not displayable]"
|
||||
)
|
||||
except (ValueError, OSError):
|
||||
pass
|
||||
|
||||
drain_thread = threading.Thread(target=_drain, daemon=True)
|
||||
drain_thread.start()
|
||||
deadline = time.monotonic() + timeout
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
self._kill_process(proc)
|
||||
drain_thread.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(output_chunks) + "\n[Command interrupted]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
self._kill_process(proc)
|
||||
drain_thread.join(timeout=2)
|
||||
partial = "".join(output_chunks)
|
||||
timeout_msg = f"\n[Command timed out after {timeout}s]"
|
||||
return {
|
||||
"output": partial + timeout_msg
|
||||
if partial
|
||||
else timeout_msg.lstrip(),
|
||||
"returncode": 124,
|
||||
}
|
||||
time.sleep(0.2)
|
||||
|
||||
drain_thread.join(timeout=5)
|
||||
|
||||
try:
|
||||
proc.stdout.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {"output": "".join(output_chunks), "returncode": proc.returncode}
|
||||
|
||||
def _kill_process(self, proc: ProcessHandle):
|
||||
"""Terminate a process. Subclasses may override for process-group kill."""
|
||||
try:
|
||||
proc.kill()
|
||||
except (ProcessLookupError, PermissionError, OSError):
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# CWD extraction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _update_cwd(self, result: dict):
|
||||
"""Extract CWD from command output. Override for local file-based read."""
|
||||
self._extract_cwd_from_output(result)
|
||||
|
||||
def _extract_cwd_from_output(self, result: dict):
|
||||
"""Parse the __HERMES_CWD_{session}__ marker from stdout output.
|
||||
|
||||
Updates self.cwd and strips the marker from result["output"].
|
||||
Used by remote backends (Docker, SSH, Modal, Daytona, Singularity).
|
||||
"""
|
||||
output = result.get("output", "")
|
||||
marker = self._cwd_marker
|
||||
last = output.rfind(marker)
|
||||
if last == -1:
|
||||
return
|
||||
|
||||
# Find the opening marker before this closing one
|
||||
search_start = max(0, last - 4096) # CWD path won't be >4KB
|
||||
first = output.rfind(marker, search_start, last)
|
||||
if first == -1 or first == last:
|
||||
return
|
||||
|
||||
cwd_path = output[first + len(marker) : last].strip()
|
||||
if cwd_path:
|
||||
self.cwd = cwd_path
|
||||
|
||||
# Strip the marker line AND the \n we injected before it.
|
||||
# The wrapper emits: printf '\n__MARKER__%s__MARKER__\n'
|
||||
# So the output looks like: <cmd output>\n__MARKER__path__MARKER__\n
|
||||
# We want to remove everything from the injected \n onwards.
|
||||
line_start = output.rfind("\n", 0, first)
|
||||
if line_start == -1:
|
||||
line_start = first
|
||||
line_end = output.find("\n", last + len(marker))
|
||||
line_end = line_end + 1 if line_end != -1 else len(output)
|
||||
|
||||
result["output"] = output[:line_start] + output[line_end:]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Hooks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _before_execute(self):
    """Hook run before every command: throttled file sync.

    Syncing is opt-in. While ``self._last_sync_time`` is ``None`` nothing
    happens; a backend enables it by setting ``self._last_sync_time = 0``
    in ``__init__`` and overriding :meth:`_sync_files`. Once enabled,
    :meth:`_sync_files` is called only when at least
    ``_SYNC_INTERVAL_SECONDS`` have passed (monotonic clock) since the
    recorded sync time, which is then updated.

    Backends needing extra pre-exec work (e.g. Daytona's sandbox restart
    check) override this and call ``super()._before_execute()``.
    """
    previous = self._last_sync_time
    if previous is None:
        return

    current = time.monotonic()
    if current - previous < _SYNC_INTERVAL_SECONDS:
        return

    self._sync_files()
    self._last_sync_time = current
|
||||
|
||||
def _sync_files(self):
    """Push files to the remote environment (no-op in the base class).

    Invoked, rate-limited, from :meth:`_before_execute`. Backends that
    need to upload files override this.
    """
    return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Unified execute()
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def execute(
|
||||
self,
|
||||
command: str,
|
||||
cwd: str = "",
|
||||
*,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None,
|
||||
) -> dict:
|
||||
"""Execute a command, return {"output": str, "returncode": int}."""
|
||||
self._before_execute()
|
||||
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
effective_timeout = timeout or self.timeout
|
||||
effective_cwd = cwd or self.cwd
|
||||
|
||||
# Merge sudo stdin with caller stdin
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
effective_stdin = sudo_stdin
|
||||
else:
|
||||
effective_stdin = stdin_data
|
||||
|
||||
# Embed stdin as heredoc for backends that need it
|
||||
if effective_stdin and self._stdin_mode == "heredoc":
|
||||
exec_command = self._embed_stdin_heredoc(exec_command, effective_stdin)
|
||||
effective_stdin = None
|
||||
|
||||
wrapped = self._wrap_command(exec_command, effective_cwd)
|
||||
|
||||
# Use login shell if snapshot failed (so user's profile still loads)
|
||||
login = not self._snapshot_ready
|
||||
|
||||
proc = self._run_bash(
|
||||
wrapped, login=login, timeout=effective_timeout, stdin_data=effective_stdin
|
||||
)
|
||||
result = self._wait_for_process(proc, timeout=effective_timeout)
|
||||
self._update_cwd(result)
|
||||
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Shared helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def stop(self):
    """Tear the environment down; kept as an alias of :meth:`cleanup`.

    Exists for compatibility with older callers that use ``stop()``.
    """
    teardown = self.cleanup
    teardown()
|
||||
|
|
@ -57,53 +544,12 @@ class BaseEnvironment(ABC):
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Shared helpers (eliminate duplication across backends)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _prepare_command(self, command: str) -> tuple[str, str | None]:
    """Transform sudo commands if SUDO_PASSWORD is available.

    Thin wrapper around ``tools.terminal_tool._transform_sudo_command``;
    see that function for the full contract.

    Args:
        command: The shell command as supplied by the caller.

    Returns:
        ``(transformed_command, sudo_stdin)``. When ``sudo_stdin`` is not
        ``None`` it must be fed to the command's stdin ahead of any
        caller-supplied stdin data.
    """
    # Imported lazily, at call time, rather than at module import.
    from tools.terminal_tool import _transform_sudo_command

    return _transform_sudo_command(command)
|
||||
|
||||
def _build_run_kwargs(self, timeout: int | None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
"""Build common subprocess.run kwargs for non-interactive execution."""
|
||||
kw = {
|
||||
"text": True,
|
||||
"timeout": timeout or self.timeout,
|
||||
"encoding": "utf-8",
|
||||
"errors": "replace",
|
||||
"stdout": subprocess.PIPE,
|
||||
"stderr": subprocess.STDOUT,
|
||||
}
|
||||
if stdin_data is not None:
|
||||
kw["input"] = stdin_data
|
||||
else:
|
||||
kw["stdin"] = subprocess.DEVNULL
|
||||
return kw
|
||||
|
||||
def execute_oneshot(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
"""Execute a command bypassing any persistent shell.
|
||||
|
||||
Safe for concurrent use alongside a long-running execute() call.
|
||||
Backends that maintain a persistent shell (SSH, Local) override this
|
||||
to route through their oneshot path, avoiding the shell lock.
|
||||
Non-persistent backends delegate to execute().
|
||||
"""
|
||||
return self.execute(command, cwd=cwd, timeout=timeout,
|
||||
stdin_data=stdin_data)
|
||||
|
||||
def _timeout_result(self, timeout: int | None) -> dict:
|
||||
"""Standard return dict when a command times out."""
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -6,17 +6,18 @@ and resumed on next creation, preserving the filesystem across sessions.
|
|||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import math
|
||||
import shlex
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
from tools.environments.base import (
|
||||
BaseEnvironment,
|
||||
_ThreadedProcessHandle,
|
||||
_file_mtime_key,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -24,22 +25,25 @@ logger = logging.getLogger(__name__)
|
|||
class DaytonaEnvironment(BaseEnvironment):
|
||||
"""Daytona cloud sandbox execution backend.
|
||||
|
||||
Uses stopped/started sandbox lifecycle for filesystem persistence
|
||||
instead of snapshots, making it faster and stateless on the host.
|
||||
Spawn-per-call via _ThreadedProcessHandle wrapping blocking SDK calls.
|
||||
cancel_fn wired to sandbox.stop() for interrupt support.
|
||||
Shell timeout wrapper preserved (SDK timeout unreliable).
|
||||
"""
|
||||
|
||||
_stdin_mode = "heredoc"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image: str,
|
||||
cwd: str = "/home/daytona",
|
||||
timeout: int = 60,
|
||||
cpu: int = 1,
|
||||
memory: int = 5120, # MB (hermes convention)
|
||||
disk: int = 10240, # MB (Daytona platform max is 10GB)
|
||||
memory: int = 5120,
|
||||
disk: int = 10240,
|
||||
persistent_filesystem: bool = True,
|
||||
task_id: str = "default",
|
||||
):
|
||||
self._requested_cwd = cwd
|
||||
requested_cwd = cwd
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
|
||||
from daytona import (
|
||||
|
|
@ -53,16 +57,18 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
self._persistent = persistent_filesystem
|
||||
self._task_id = task_id
|
||||
self._SandboxState = SandboxState
|
||||
self._DaytonaError = DaytonaError
|
||||
self._daytona = Daytona()
|
||||
self._sandbox = None
|
||||
self._lock = threading.Lock()
|
||||
self._last_sync_time: float = 0
|
||||
|
||||
memory_gib = max(1, math.ceil(memory / 1024))
|
||||
disk_gib = max(1, math.ceil(disk / 1024))
|
||||
if disk_gib > 10:
|
||||
warnings.warn(
|
||||
f"Daytona: requested disk ({disk_gib}GB) exceeds platform limit (10GB). "
|
||||
f"Capping to 10GB. Set container_disk: 10240 in config to silence this.",
|
||||
f"Capping to 10GB.",
|
||||
stacklevel=2,
|
||||
)
|
||||
disk_gib = 10
|
||||
|
|
@ -71,9 +77,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
labels = {"hermes_task_id": task_id}
|
||||
sandbox_name = f"hermes-{task_id}"
|
||||
|
||||
# Try to resume an existing sandbox for this task
|
||||
if self._persistent:
|
||||
# 1. Try name-based lookup (new path)
|
||||
try:
|
||||
self._sandbox = self._daytona.get(sandbox_name)
|
||||
self._sandbox.start()
|
||||
|
|
@ -86,7 +90,6 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
task_id, e)
|
||||
self._sandbox = None
|
||||
|
||||
# 2. Legacy fallback: find sandbox created before the naming migration
|
||||
if self._sandbox is None:
|
||||
try:
|
||||
page = self._daytona.list(labels=labels, page=1, limit=1)
|
||||
|
|
@ -100,7 +103,6 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
task_id, e)
|
||||
self._sandbox = None
|
||||
|
||||
# Create a fresh sandbox if we don't have one
|
||||
if self._sandbox is None:
|
||||
self._sandbox = self._daytona.create(
|
||||
CreateSandboxFromImageParams(
|
||||
|
|
@ -114,32 +116,25 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
logger.info("Daytona: created sandbox %s for task %s",
|
||||
self._sandbox.id, task_id)
|
||||
|
||||
# Detect remote home dir first so mounts go to the right place.
|
||||
# Detect remote home dir
|
||||
self._remote_home = "/root"
|
||||
try:
|
||||
home = self._sandbox.process.exec("echo $HOME").result.strip()
|
||||
if home:
|
||||
self._remote_home = home
|
||||
if self._requested_cwd in ("~", "/home/daytona"):
|
||||
if requested_cwd in ("~", "/home/daytona"):
|
||||
self.cwd = home
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("Daytona: resolved home to %s, cwd to %s", self._remote_home, self.cwd)
|
||||
|
||||
# Track synced files to avoid redundant uploads.
|
||||
# Key: remote_path, Value: (mtime, size)
|
||||
self._synced_files: Dict[str, tuple] = {}
|
||||
|
||||
# Upload credential files and skills directory into the sandbox.
|
||||
self._sync_skills_and_credentials()
|
||||
self._sync_files()
|
||||
self.init_session()
|
||||
|
||||
def _upload_if_changed(self, host_path: str, remote_path: str) -> bool:
|
||||
"""Upload a file if its mtime/size changed since last sync."""
|
||||
hp = Path(host_path)
|
||||
try:
|
||||
stat = hp.stat()
|
||||
file_key = (stat.st_mtime, stat.st_size)
|
||||
except OSError:
|
||||
file_key = _file_mtime_key(host_path)
|
||||
if file_key is None:
|
||||
return False
|
||||
if self._synced_files.get(remote_path) == file_key:
|
||||
return False
|
||||
|
|
@ -153,20 +148,15 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
logger.debug("Daytona: upload failed %s: %s", host_path, e)
|
||||
return False
|
||||
|
||||
def _sync_skills_and_credentials(self) -> None:
|
||||
"""Upload changed credential files and skill files into the sandbox."""
|
||||
def _sync_files(self) -> None:
|
||||
container_base = f"{self._remote_home}/.hermes"
|
||||
try:
|
||||
from tools.credential_files import get_credential_file_mounts, iter_skills_files
|
||||
|
||||
for mount_entry in get_credential_file_mounts():
|
||||
remote_path = mount_entry["container_path"].replace("/root/.hermes", container_base, 1)
|
||||
if self._upload_if_changed(mount_entry["host_path"], remote_path):
|
||||
logger.debug("Daytona: synced credential %s", remote_path)
|
||||
|
||||
self._upload_if_changed(mount_entry["host_path"], remote_path)
|
||||
for entry in iter_skills_files(container_base=container_base):
|
||||
if self._upload_if_changed(entry["host_path"], entry["container_path"]):
|
||||
logger.debug("Daytona: synced skill %s", entry["container_path"])
|
||||
self._upload_if_changed(entry["host_path"], entry["container_path"])
|
||||
except Exception as e:
|
||||
logger.debug("Daytona: could not sync skills/credentials: %s", e)
|
||||
|
||||
|
|
@ -177,111 +167,36 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
self._sandbox.start()
|
||||
logger.info("Daytona: restarted sandbox %s", self._sandbox.id)
|
||||
|
||||
def _exec_in_thread(self, exec_command: str, cwd: Optional[str], timeout: int) -> dict:
    """Run exec in a background thread with interrupt polling.

    The Daytona SDK's exec(timeout=...) parameter is unreliable (the
    server-side timeout is not enforced and the SDK has no client-side
    fallback), so we wrap the command with the shell ``timeout`` utility
    which reliably kills the process and returns exit code 124.

    Args:
        exec_command: Shell command text; it is quoted and run via ``sh -c``.
        cwd: Working directory passed through to the SDK ``exec`` call.
        timeout: Deadline in seconds given to the shell ``timeout`` utility.

    Returns:
        One of:
        * ``{"output": str, "returncode": int}`` on normal completion,
        * a fixed "interrupted" dict with ``returncode`` 130 if
          ``is_interrupted()`` becomes true while waiting,
        * ``self._timeout_result(timeout)`` if the worker thread is still
          alive more than ``timeout + 10`` seconds after start,
        * ``{"error": exc}`` if the SDK call raised; the exception is
          returned to the caller, not raised.
    """
    # Wrap with shell `timeout` to enforce the deadline reliably.
    # The shell gets exactly `timeout` seconds; no SDK-level timeout is
    # passed to exec() below. The extra slack lives in the polling
    # deadline further down, not here.
    timed_command = f"timeout {timeout} sh -c {shlex.quote(exec_command)}"

    # Mutable holder so the worker thread can hand back a value or an error.
    result_holder: dict = {"value": None, "error": None}

    def _run():
        # Worker: performs the blocking SDK call and records its outcome.
        try:
            response = self._sandbox.process.exec(
                timed_command, cwd=cwd,
            )
            result_holder["value"] = {
                "output": response.result or "",
                "returncode": response.exit_code,
            }
        except Exception as e:
            result_holder["error"] = e

    t = threading.Thread(target=_run, daemon=True)
    t.start()
    # Wait for timeout + generous buffer for network/SDK overhead
    deadline = time.monotonic() + timeout + 10
    while t.is_alive():
        # Short joins keep the loop responsive to interrupts.
        t.join(timeout=0.2)
        if is_interrupted():
            # Stop the whole sandbox; failures to stop are ignored and the
            # worker thread is left to finish on its own (it is a daemon).
            with self._lock:
                try:
                    self._sandbox.stop()
                except Exception:
                    pass
            return {
                "output": "[Command interrupted - Daytona sandbox stopped]",
                "returncode": 130,
            }
        if time.monotonic() > deadline:
            # Shell timeout didn't fire and SDK is hung — force stop
            with self._lock:
                try:
                    self._sandbox.stop()
                except Exception:
                    pass
            return self._timeout_result(timeout)

    # Worker finished: surface its error (as a dict) or its result.
    if result_holder["error"]:
        return {"error": result_holder["error"]}
    return result_holder["value"]
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: Optional[int] = None,
|
||||
stdin_data: Optional[str] = None) -> dict:
|
||||
def _before_execute(self):
|
||||
"""Ensure sandbox is ready, then rate-limited file sync via base class."""
|
||||
with self._lock:
|
||||
self._ensure_sandbox_ready()
|
||||
# Incremental sync before each command so mid-session credential
|
||||
# refreshes and skill updates are picked up.
|
||||
self._sync_skills_and_credentials()
|
||||
super()._before_execute()
|
||||
|
||||
if stdin_data is not None:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
while marker in stdin_data:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
command = f"{command} << '{marker}'\n{stdin_data}\n{marker}"
|
||||
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None):
|
||||
"""Return a _ThreadedProcessHandle wrapping a blocking Daytona SDK call."""
|
||||
sandbox = self._sandbox
|
||||
lock = self._lock
|
||||
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
def cancel():
|
||||
with lock:
|
||||
try:
|
||||
sandbox.stop()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Daytona sandboxes execute commands via the Daytona SDK and cannot
|
||||
# pipe subprocess stdin directly the way a local Popen can. When a
|
||||
# sudo password is present, use a shell-level pipe from printf so that
|
||||
# the password feeds sudo -S without appearing as an echo argument
|
||||
# embedded in the shell string. The password is still visible in the
|
||||
# remote sandbox's command line, but it is not exposed on the user's
|
||||
# local machine — which is the primary threat being mitigated.
|
||||
if sudo_stdin is not None:
|
||||
import shlex
|
||||
exec_command = (
|
||||
f"printf '%s\\n' {shlex.quote(sudo_stdin.rstrip())} | {exec_command}"
|
||||
)
|
||||
effective_cwd = cwd or self.cwd or None
|
||||
effective_timeout = timeout or self.timeout
|
||||
if login:
|
||||
shell_cmd = f"bash -l -c {shlex.quote(cmd_string)}"
|
||||
else:
|
||||
shell_cmd = f"bash -c {shlex.quote(cmd_string)}"
|
||||
|
||||
result = self._exec_in_thread(exec_command, effective_cwd, effective_timeout)
|
||||
def exec_fn() -> tuple[str, int]:
|
||||
response = sandbox.process.exec(shell_cmd, timeout=timeout)
|
||||
return (response.result or "", response.exit_code)
|
||||
|
||||
if "error" in result:
|
||||
from daytona import DaytonaError
|
||||
err = result["error"]
|
||||
if isinstance(err, DaytonaError):
|
||||
with self._lock:
|
||||
try:
|
||||
self._ensure_sandbox_ready()
|
||||
except Exception:
|
||||
return {"output": f"Daytona execution error: {err}", "returncode": 1}
|
||||
result = self._exec_in_thread(exec_command, effective_cwd, effective_timeout)
|
||||
if "error" not in result:
|
||||
return result
|
||||
return {"output": f"Daytona execution error: {err}", "returncode": 1}
|
||||
|
||||
return result
|
||||
return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
||||
|
||||
def cleanup(self):
|
||||
with self._lock:
|
||||
|
|
|
|||
|
|
@ -8,18 +8,14 @@ persistence via bind mounts.
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.environments.base import BaseEnvironment, _popen_bash
|
||||
from tools.environments.local import _HERMES_PROVIDER_ENV_BLOCKLIST
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -431,6 +427,69 @@ class DockerEnvironment(BaseEnvironment):
|
|||
self._container_id = result.stdout.strip()
|
||||
logger.info(f"Started container {container_name} ({self._container_id[:12]})")
|
||||
|
||||
# Build the init-time env forwarding args (used only by init_session
|
||||
# to inject host env vars into the snapshot; subsequent commands get
|
||||
# them from the snapshot file).
|
||||
self._init_env_args = self._build_init_env_args()
|
||||
|
||||
# Initialize session snapshot inside the container
|
||||
self.init_session()
|
||||
|
||||
def _build_init_env_args(self) -> list[str]:
    """Return the ``-e KEY=VALUE`` argument list used for ``init_session``.

    The variables are injected once, while the session snapshot is being
    created, so that ``export -p`` records them; later ``execute()`` calls
    rely on the snapshot rather than on ``-e`` flags.

    Precedence: static ``self._env`` entries form the base; forwarded keys
    (``self._forward_env`` plus non-blocklisted passthrough keys) overlay
    them, taking the value from the host process environment first and
    from the Hermes env file second. Keys with no value anywhere are
    skipped.

    Returns:
        A flat list ``["-e", "K1=V1", "-e", "K2=V2", ...]`` sorted by key.
    """
    merged: dict[str, str] = dict(self._env)

    explicit = set(self._forward_env)
    implicit: set[str] = set()
    try:
        from tools.env_passthrough import get_all_passthrough
        implicit = set(get_all_passthrough())
    except Exception:
        # Passthrough is optional; carry on with explicit keys only.
        pass

    # Explicit docker_forward_env entries are a deliberate opt-in and are
    # never filtered; the Hermes secret blocklist only applies to the
    # implicit passthrough keys.
    wanted = explicit | (implicit - _HERMES_PROVIDER_ENV_BLOCKLIST)
    fallback = _load_hermes_env_vars() if wanted else {}

    for name in sorted(wanted):
        resolved = os.getenv(name)
        if resolved is None:
            resolved = fallback.get(name)
        if resolved is not None:
            merged[name] = resolved

    return [
        token
        for name in sorted(merged)
        for token in ("-e", f"{name}={merged[name]}")
    ]
|
||||
|
||||
def _run_bash(self, cmd_string: str, *, login: bool = False,
              timeout: int = 120,
              stdin_data: str | None = None) -> subprocess.Popen:
    """Start ``bash`` inside the container via ``docker exec``.

    Args:
        cmd_string: Script passed to ``bash -c``.
        login: When true, run ``bash -l`` and add the init-time ``-e``
            env arguments; otherwise the environment is expected to come
            from the session snapshot.
        timeout: Accepted for interface compatibility; not used here.
        stdin_data: When not ``None``, ``docker exec -i`` is used and the
            data is handed to ``_popen_bash``.

    Returns:
        The process object returned by ``_popen_bash``.
    """
    assert self._container_id, "Container not started"

    argv = [self._docker_exe, "exec"]
    if stdin_data is not None:
        argv.append("-i")

    # -e flags are only needed for login shells (init_session); other
    # commands pick the variables up from the snapshot.
    if login:
        argv.extend(self._init_env_args)

    argv.append(self._container_id)

    shell_argv = ["bash", "-l", "-c", cmd_string] if login else ["bash", "-c", cmd_string]
    argv.extend(shell_argv)

    return _popen_bash(argv, stdin_data)
|
||||
|
||||
@staticmethod
|
||||
def _storage_opt_supported() -> bool:
|
||||
"""Check if Docker's storage driver supports --storage-opt size=.
|
||||
|
|
@ -471,112 +530,6 @@ class DockerEnvironment(BaseEnvironment):
|
|||
logger.debug("Docker --storage-opt support: %s", _storage_opt_ok)
|
||||
return _storage_opt_ok
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
            timeout: int | None = None,
            stdin_data: str | None = None) -> dict:
    """Run *command* in the container with ``docker exec`` and collect output.

    Args:
        command: Shell command; passed through ``_prepare_command`` first.
        cwd: Working directory; ``self.cwd`` is used when empty.
        timeout: Seconds to allow; ``self.timeout`` is used when falsy.
        stdin_data: Optional text written to the process's stdin (after
            the sudo password, when one is returned by ``_prepare_command``).

    Returns:
        ``{"output": str, "returncode": int}``. ``returncode`` is 130 when
        the run is interrupted, the value from ``_timeout_result`` on
        timeout, and 1 (with an error message as output) if anything in
        the spawn/wait section raises.
    """
    exec_command, sudo_stdin = self._prepare_command(command)
    work_dir = cwd or self.cwd
    effective_timeout = timeout or self.timeout

    # Merge sudo password (if any) with caller-supplied stdin_data.
    if sudo_stdin is not None and stdin_data is not None:
        effective_stdin = sudo_stdin + stdin_data
    elif sudo_stdin is not None:
        effective_stdin = sudo_stdin
    else:
        effective_stdin = stdin_data

    # docker exec -w doesn't expand ~, so prepend a cd into the command.
    # Keep ~ unquoted (for shell expansion) and quote only the subpath.
    if work_dir == "~":
        exec_command = f"cd ~ && {exec_command}"
        work_dir = "/"
    elif work_dir.startswith("~/"):
        exec_command = f"cd ~/{shlex.quote(work_dir[2:])} && {exec_command}"
        work_dir = "/"

    assert self._container_id, "Container not started"
    cmd = [self._docker_exe, "exec"]
    # NOTE(review): "-i" is added for any non-None stdin (including ""),
    # while the Popen call below only opens a pipe for non-empty stdin.
    if effective_stdin is not None:
        cmd.append("-i")
    cmd.extend(["-w", work_dir])
    # Build the per-exec environment: start with explicit docker_env values
    # (static config), then overlay docker_forward_env / skill env_passthrough
    # (dynamic from host process). Forward values take precedence.
    exec_env: dict[str, str] = dict(self._env)

    explicit_forward_keys = set(self._forward_env)
    passthrough_keys: set[str] = set()
    try:
        from tools.env_passthrough import get_all_passthrough
        passthrough_keys = set(get_all_passthrough())
    except Exception:
        # Passthrough support is optional; continue with explicit keys only.
        pass
    # Explicit docker_forward_env entries are an intentional opt-in and must
    # win over the generic Hermes secret blocklist. Only implicit passthrough
    # keys are filtered.
    forward_keys = explicit_forward_keys | (passthrough_keys - _HERMES_PROVIDER_ENV_BLOCKLIST)
    hermes_env = _load_hermes_env_vars() if forward_keys else {}
    for key in sorted(forward_keys):
        # Host process environment first, Hermes env file as fallback.
        value = os.getenv(key)
        if value is None:
            value = hermes_env.get(key)
        if value is not None:
            exec_env[key] = value

    # Sorted so the resulting argv is deterministic.
    for key in sorted(exec_env):
        cmd.extend(["-e", f"{key}={exec_env[key]}"])
    cmd.extend([self._container_id, "bash", "-lc", exec_command])

    try:
        _output_chunks = []
        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
            stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
            text=True,
        )
        if effective_stdin:
            # Best effort: a failed write/close is ignored and the command
            # is still waited on.
            try:
                proc.stdin.write(effective_stdin)
                proc.stdin.close()
            except Exception:
                pass

        def _drain():
            # Reader thread body: accumulate output line by line so the
            # polling loop below never blocks on a read.
            try:
                for line in proc.stdout:
                    _output_chunks.append(line)
            except Exception:
                pass

        reader = threading.Thread(target=_drain, daemon=True)
        reader.start()
        deadline = time.monotonic() + effective_timeout

        # Poll every 0.2 s for completion, interrupt, or deadline.
        while proc.poll() is None:
            if is_interrupted():
                # Try a graceful terminate first, then kill after 1 s.
                proc.terminate()
                try:
                    proc.wait(timeout=1)
                except subprocess.TimeoutExpired:
                    proc.kill()
                reader.join(timeout=2)
                return {
                    "output": "".join(_output_chunks) + "\n[Command interrupted]",
                    "returncode": 130,
                }
            if time.monotonic() > deadline:
                proc.kill()
                reader.join(timeout=2)
                return self._timeout_result(effective_timeout)
            time.sleep(0.2)

        # Process exited: give the reader a moment to flush remaining output.
        reader.join(timeout=5)
        return {"output": "".join(_output_chunks), "returncode": proc.returncode}
    except Exception as e:
        return {"output": f"Docker execution error: {e}", "returncode": 1}
|
||||
|
||||
def cleanup(self):
|
||||
"""Stop and remove the container. Bind-mount dirs persist if persistent=True."""
|
||||
if self._container_id:
|
||||
|
|
|
|||
|
|
@ -1,42 +1,22 @@
|
|||
"""Local execution environment with interrupt support and non-blocking I/O."""
|
||||
"""Local execution environment — spawn-per-call with session snapshot."""
|
||||
|
||||
import glob
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
|
||||
from tools.environments.base import BaseEnvironment, _pipe_stdin
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.environments.persistent_shell import PersistentShellMixin
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
# Unique marker to isolate real command output from shell init/exit noise.
|
||||
# printf (no trailing newline) keeps the boundaries clean for splitting.
|
||||
_OUTPUT_FENCE = "__HERMES_FENCE_a9f7b3__"
|
||||
|
||||
# Hermes-internal env vars that should NOT leak into terminal subprocesses.
|
||||
# These are loaded from ~/.hermes/.env for Hermes' own LLM/provider calls
|
||||
# but can break external CLIs (e.g. codex) that also honor them.
|
||||
# See: https://github.com/NousResearch/hermes-agent/issues/1002
|
||||
#
|
||||
# Built dynamically from the provider registry so new providers are
|
||||
# automatically covered without manual blocklist maintenance.
|
||||
_HERMES_PROVIDER_ENV_FORCE_PREFIX = "_HERMES_FORCE_"
|
||||
|
||||
|
||||
def _build_provider_env_blocklist() -> frozenset:
|
||||
"""Derive the blocklist from provider, tool, and gateway config.
|
||||
|
||||
Automatically picks up api_key_env_vars and base_url_env_var from
|
||||
every registered provider, plus tool/messaging env vars from the
|
||||
optional config registry, so new Hermes-managed secrets are blocked
|
||||
in subprocesses without having to maintain multiple static lists.
|
||||
"""
|
||||
"""Derive the blocklist from provider, tool, and gateway config."""
|
||||
blocked: set[str] = set()
|
||||
|
||||
try:
|
||||
|
|
@ -59,33 +39,30 @@ def _build_provider_env_blocklist() -> frozenset:
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
# Vars not covered above but still Hermes-internal / conflict-prone.
|
||||
blocked.update({
|
||||
"OPENAI_BASE_URL",
|
||||
"OPENAI_API_KEY",
|
||||
"OPENAI_API_BASE", # legacy alias
|
||||
"OPENAI_API_BASE",
|
||||
"OPENAI_ORG_ID",
|
||||
"OPENAI_ORGANIZATION",
|
||||
"OPENROUTER_API_KEY",
|
||||
"ANTHROPIC_BASE_URL",
|
||||
"ANTHROPIC_TOKEN", # OAuth token (not in registry as env var)
|
||||
"ANTHROPIC_TOKEN",
|
||||
"CLAUDE_CODE_OAUTH_TOKEN",
|
||||
"LLM_MODEL",
|
||||
# Expanded isolation for other major providers (Issue #1002)
|
||||
"GOOGLE_API_KEY", # Gemini / Google AI Studio
|
||||
"DEEPSEEK_API_KEY", # DeepSeek
|
||||
"MISTRAL_API_KEY", # Mistral AI
|
||||
"GROQ_API_KEY", # Groq
|
||||
"TOGETHER_API_KEY", # Together AI
|
||||
"PERPLEXITY_API_KEY", # Perplexity
|
||||
"COHERE_API_KEY", # Cohere
|
||||
"FIREWORKS_API_KEY", # Fireworks AI
|
||||
"XAI_API_KEY", # xAI (Grok)
|
||||
"HELICONE_API_KEY", # LLM Observability proxy
|
||||
"GOOGLE_API_KEY",
|
||||
"DEEPSEEK_API_KEY",
|
||||
"MISTRAL_API_KEY",
|
||||
"GROQ_API_KEY",
|
||||
"TOGETHER_API_KEY",
|
||||
"PERPLEXITY_API_KEY",
|
||||
"COHERE_API_KEY",
|
||||
"FIREWORKS_API_KEY",
|
||||
"XAI_API_KEY",
|
||||
"HELICONE_API_KEY",
|
||||
"PARALLEL_API_KEY",
|
||||
"FIRECRAWL_API_KEY",
|
||||
"FIRECRAWL_API_URL",
|
||||
# Gateway/runtime config not represented in OPTIONAL_ENV_VARS.
|
||||
"TELEGRAM_HOME_CHANNEL",
|
||||
"TELEGRAM_HOME_CHANNEL_NAME",
|
||||
"DISCORD_HOME_CHANNEL",
|
||||
|
|
@ -115,12 +92,10 @@ def _build_provider_env_blocklist() -> frozenset:
|
|||
"EMAIL_HOME_ADDRESS",
|
||||
"EMAIL_HOME_ADDRESS_NAME",
|
||||
"GATEWAY_ALLOWED_USERS",
|
||||
# Skills Hub / GitHub app auth paths and aliases.
|
||||
"GH_TOKEN",
|
||||
"GITHUB_APP_ID",
|
||||
"GITHUB_APP_PRIVATE_KEY_PATH",
|
||||
"GITHUB_APP_INSTALLATION_ID",
|
||||
# Remote sandbox backend credentials.
|
||||
"MODAL_TOKEN_ID",
|
||||
"MODAL_TOKEN_SECRET",
|
||||
"DAYTONA_API_KEY",
|
||||
|
|
@ -132,13 +107,7 @@ _HERMES_PROVIDER_ENV_BLOCKLIST = _build_provider_env_blocklist()
|
|||
|
||||
|
||||
def _sanitize_subprocess_env(base_env: dict | None, extra_env: dict | None = None) -> dict:
|
||||
"""Filter Hermes-managed secrets from a subprocess environment.
|
||||
|
||||
`_HERMES_FORCE_<VAR>` entries in ``extra_env`` opt a blocked variable back in
|
||||
intentionally for callers that truly need it. Vars registered via
|
||||
:mod:`tools.env_passthrough` (skill-declared or user-configured) also
|
||||
bypass the blocklist.
|
||||
"""
|
||||
"""Filter Hermes-managed secrets from a subprocess environment."""
|
||||
try:
|
||||
from tools.env_passthrough import is_env_passthrough as _is_passthrough
|
||||
except Exception:
|
||||
|
|
@ -163,33 +132,24 @@ def _sanitize_subprocess_env(base_env: dict | None, extra_env: dict | None = Non
|
|||
|
||||
|
||||
def _find_bash() -> str:
|
||||
"""Find bash for command execution.
|
||||
|
||||
The fence wrapper uses bash syntax (semicolons, $?, printf), so we
|
||||
must use bash — not the user's $SHELL which could be fish/zsh/etc.
|
||||
On Windows: uses Git Bash (bundled with Git for Windows).
|
||||
"""
|
||||
"""Find bash for command execution."""
|
||||
if not _IS_WINDOWS:
|
||||
return (
|
||||
shutil.which("bash")
|
||||
or ("/usr/bin/bash" if os.path.isfile("/usr/bin/bash") else None)
|
||||
or ("/bin/bash" if os.path.isfile("/bin/bash") else None)
|
||||
or os.environ.get("SHELL") # last resort: whatever they have
|
||||
or os.environ.get("SHELL")
|
||||
or "/bin/sh"
|
||||
)
|
||||
|
||||
# Windows: look for Git Bash (installed with Git for Windows).
|
||||
# Allow override via env var (same pattern as Claude Code).
|
||||
custom = os.environ.get("HERMES_GIT_BASH_PATH")
|
||||
if custom and os.path.isfile(custom):
|
||||
return custom
|
||||
|
||||
# shutil.which finds bash.exe if Git\bin is on PATH
|
||||
found = shutil.which("bash")
|
||||
if found:
|
||||
return found
|
||||
|
||||
# Check common Git for Windows install locations
|
||||
for candidate in (
|
||||
os.path.join(os.environ.get("ProgramFiles", r"C:\Program Files"), "Git", "bin", "bash.exe"),
|
||||
os.path.join(os.environ.get("ProgramFiles(x86)", r"C:\Program Files (x86)"), "Git", "bin", "bash.exe"),
|
||||
|
|
@ -209,60 +169,7 @@ def _find_bash() -> str:
|
|||
_find_shell = _find_bash
|
||||
|
||||
|
||||
# Noise lines emitted by interactive shells when stdin is not a terminal.
|
||||
# Used as a fallback when output fence markers are missing.
|
||||
_SHELL_NOISE_SUBSTRINGS = (
    # bash
    "bash: cannot set terminal process group",
    "bash: no job control in this shell",
    "no job control in this shell",
    "cannot set terminal process group",
    "tcsetattr: Inappropriate ioctl for device",
    # zsh / oh-my-zsh / macOS terminal session
    "Restored session:",
    "Saving session...",
    "Last login:",
    "command not found:",
    "Oh My Zsh",
    "compinit:",
)


def _clean_shell_noise(output: str) -> str:
    """Strip shell startup/exit warnings that leak when using -i without a TTY.

    Lines containing any of ``_SHELL_NOISE_SUBSTRINGS`` are removed from the
    beginning and from the end of *output*; matching lines in the middle
    are left untouched. Trailing empty lines are dropped along with the
    trailing noise, and a single trailing newline is restored if *output*
    ended with one.

    Args:
        output: Raw captured shell output.

    Returns:
        The cleaned text, or ``""`` when nothing but noise/blank lines
        remains.
    """

    def _is_noise(line: str) -> bool:
        return any(noise in line for noise in _SHELL_NOISE_SUBSTRINGS)

    lines = output.split("\n")
    total = len(lines)

    # Skip leading noise by advancing an index instead of repeatedly
    # calling list.pop(0), which is O(n) per call.
    start = 0
    while start < total and _is_noise(lines[start]):
        start += 1

    # Walk backwards over trailing noise and the empty strings produced by
    # splitting on a trailing newline.
    end = total - 1
    while end >= start and (not lines[end] or _is_noise(lines[end])):
        end -= 1

    if end < start:
        return ""

    result = "\n".join(lines[start:end + 1])

    # Preserve trailing newline if original had one
    if output.endswith("\n") and result and not result.endswith("\n"):
        result += "\n"
    return result
|
||||
|
||||
|
||||
# Standard PATH entries for environments with minimal PATH (e.g. systemd services).
|
||||
# Includes macOS Homebrew paths (/opt/homebrew/* for Apple Silicon).
|
||||
# Standard PATH entries for environments with minimal PATH.
|
||||
_SANE_PATH = (
|
||||
"/opt/homebrew/bin:/opt/homebrew/sbin:"
|
||||
"/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
|
|
@ -290,197 +197,76 @@ def _make_run_env(env: dict) -> dict:
|
|||
return run_env
|
||||
|
||||
|
||||
def _extract_fenced_output(raw: str) -> str:
    """Return the text that sits between the first and last output fence.

    The command wrapper prints ``_OUTPUT_FENCE`` once before and once after
    the user's command, so everything outside the fences is shell
    init/exit chatter.

    Fallbacks:
        * no fence at all -> pattern-based ``_clean_shell_noise`` on *raw*;
        * only one fence (e.g. the command itself called ``exit``) ->
          ``_clean_shell_noise`` on everything after that fence.
    """
    opening = raw.find(_OUTPUT_FENCE)
    if opening == -1:
        return _clean_shell_noise(raw)

    body_start = opening + len(_OUTPUT_FENCE)
    closing = raw.rfind(_OUTPUT_FENCE)

    if closing > opening:
        # Two distinct fences: the slice between them is the real output.
        return raw[body_start:closing]

    # rfind landed on the opening fence again, so the closing one is missing.
    return _clean_shell_noise(raw[body_start:])
|
||||
|
||||
|
||||
class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
|
||||
class LocalEnvironment(BaseEnvironment):
|
||||
"""Run commands directly on the host machine.
|
||||
|
||||
Features:
|
||||
- Popen + polling for interrupt support (user can cancel mid-command)
|
||||
- Background stdout drain thread to prevent pipe buffer deadlocks
|
||||
- stdin_data support for piping content (bypasses ARG_MAX limits)
|
||||
- sudo -S transform via SUDO_PASSWORD env var
|
||||
- Uses interactive login shell so full user env is available
|
||||
- Optional persistent shell mode (cwd/env vars survive across calls)
|
||||
Spawn-per-call: every execute() spawns a fresh bash process.
|
||||
Session snapshot preserves env vars across calls.
|
||||
CWD persists via file-based read after each command.
|
||||
"""
|
||||
|
||||
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None,
|
||||
persistent: bool = False):
|
||||
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None):
|
||||
super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env)
|
||||
self.persistent = persistent
|
||||
if self.persistent:
|
||||
self._init_persistent_shell()
|
||||
self.init_session()
|
||||
|
||||
@property
|
||||
def _temp_prefix(self) -> str:
|
||||
return f"/tmp/hermes-local-{self._session_id}"
|
||||
|
||||
def _spawn_shell_process(self) -> subprocess.Popen:
    """Start a login bash (``bash -l``) wired up with text-mode pipes.

    stdin and stdout are pipes, stderr is discarded, and the environment
    comes from ``_make_run_env(self.env)``.

    Returns:
        The running ``subprocess.Popen`` object.
    """
    user_shell = _find_bash()
    run_env = _make_run_env(self.env)
    return subprocess.Popen(
        [user_shell, "-l"],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
        text=True,
        env=run_env,
        # Same effect as the former preexec_fn=os.setsid (child becomes a
        # session leader on POSIX), but without running Python code between
        # fork and exec, which the subprocess docs flag as unsafe when the
        # parent has other threads.
        start_new_session=not _IS_WINDOWS,
    )
|
||||
|
||||
def _read_temp_files(self, *paths: str) -> list[str]:
|
||||
results = []
|
||||
for path in paths:
|
||||
if os.path.exists(path):
|
||||
with open(path) as f:
|
||||
results.append(f.read())
|
||||
else:
|
||||
results.append("")
|
||||
return results
|
||||
|
||||
def _kill_shell_children(self):
    """Best-effort: run ``pkill -P <shell pid>`` against the shell's children.

    Does nothing when no shell PID has been recorded. A missing ``pkill``
    binary or a call that exceeds 5 seconds is silently ignored.
    """
    shell_pid = self._shell_pid
    if shell_pid is not None:
        pkill_argv = ["pkill", "-P", str(shell_pid)]
        try:
            subprocess.run(pkill_argv, capture_output=True, timeout=5)
        except (subprocess.TimeoutExpired, FileNotFoundError):
            # Nothing useful to do if pkill is absent or hangs.
            pass
|
||||
|
||||
def _cleanup_temp_files(self):
|
||||
for f in glob.glob(f"{self._temp_prefix}-*"):
|
||||
if os.path.exists(f):
|
||||
os.remove(f)
|
||||
|
||||
def _execute_oneshot(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
work_dir = cwd or self.cwd or os.getcwd()
|
||||
effective_timeout = timeout or self.timeout
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
effective_stdin = sudo_stdin
|
||||
else:
|
||||
effective_stdin = stdin_data
|
||||
|
||||
user_shell = _find_bash()
|
||||
# Newline-separated wrapper (not `cmd; __hermes_rc=...` on one line).
|
||||
# A trailing `; __hermes_rc` glued to `<<EOF` / a closing `EOF` line breaks
|
||||
# heredoc parsing: the delimiter must be alone on its line, otherwise the
|
||||
# rest of this script becomes heredoc body and leaks into stdout (e.g. gh
|
||||
# issue/PR flows that use here-documents for bodies).
|
||||
fenced_cmd = (
|
||||
f"printf '{_OUTPUT_FENCE}'\n"
|
||||
f"{exec_command}\n"
|
||||
f"__hermes_rc=$?\n"
|
||||
f"printf '{_OUTPUT_FENCE}'\n"
|
||||
f"exit $__hermes_rc\n"
|
||||
)
|
||||
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None) -> subprocess.Popen:
|
||||
bash = _find_bash()
|
||||
args = [bash, "-l", "-c", cmd_string] if login else [bash, "-c", cmd_string]
|
||||
run_env = _make_run_env(self.env)
|
||||
|
||||
proc = subprocess.Popen(
|
||||
[user_shell, "-lic", fenced_cmd],
|
||||
args,
|
||||
text=True,
|
||||
cwd=work_dir,
|
||||
env=run_env,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL,
|
||||
stdin=subprocess.PIPE if stdin_data is not None else subprocess.DEVNULL,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
)
|
||||
|
||||
if effective_stdin is not None:
|
||||
def _write_stdin():
|
||||
if stdin_data is not None:
|
||||
_pipe_stdin(proc, stdin_data)
|
||||
|
||||
return proc
|
||||
|
||||
def _kill_process(self, proc):
|
||||
"""Kill the entire process group (all children)."""
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
proc.terminate()
|
||||
else:
|
||||
pgid = os.getpgid(proc.pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
try:
|
||||
proc.stdin.write(effective_stdin)
|
||||
proc.stdin.close()
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
threading.Thread(target=_write_stdin, daemon=True).start()
|
||||
|
||||
_output_chunks: list[str] = []
|
||||
|
||||
def _drain_stdout():
|
||||
proc.wait(timeout=1.0)
|
||||
except subprocess.TimeoutExpired:
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
_output_chunks.append(line)
|
||||
except ValueError:
|
||||
proc.kill()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
proc.stdout.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
reader = threading.Thread(target=_drain_stdout, daemon=True)
|
||||
reader.start()
|
||||
deadline = time.monotonic() + effective_timeout
|
||||
def _update_cwd(self, result: dict):
|
||||
"""Read CWD from temp file (local-only, no round-trip needed)."""
|
||||
try:
|
||||
cwd_path = open(self._cwd_file).read().strip()
|
||||
if cwd_path:
|
||||
self.cwd = cwd_path
|
||||
except (OSError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
proc.terminate()
|
||||
else:
|
||||
pgid = os.getpgid(proc.pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
try:
|
||||
proc.wait(timeout=1.0)
|
||||
except subprocess.TimeoutExpired:
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
proc.terminate()
|
||||
else:
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
partial = "".join(_output_chunks)
|
||||
timeout_msg = f"\n[Command timed out after {effective_timeout}s]"
|
||||
return {
|
||||
"output": partial + timeout_msg if partial else timeout_msg.lstrip(),
|
||||
"returncode": 124,
|
||||
}
|
||||
time.sleep(0.2)
|
||||
# Still strip the marker from output so it's not visible
|
||||
self._extract_cwd_from_output(result)
|
||||
|
||||
reader.join(timeout=5)
|
||||
output = _extract_fenced_output("".join(_output_chunks))
|
||||
return {"output": output, "returncode": proc.returncode}
|
||||
def cleanup(self):
    """Remove ``self._snapshot_path`` and ``self._cwd_file``, ignoring errors."""
    stale_paths = (self._snapshot_path, self._cwd_file)
    for stale in stale_paths:
        try:
            os.unlink(stale)
        except OSError:
            # Already gone or not removable; cleanup is best-effort.
            continue
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import uuid
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from tools.environments.modal_common import (
|
||||
from tools.environments.modal_utils import (
|
||||
BaseModalExecutionEnvironment,
|
||||
ModalExecStart,
|
||||
PreparedModalExec,
|
||||
|
|
|
|||
|
|
@ -5,19 +5,19 @@ wrapper, while preserving Hermes' persistent snapshot behavior across sessions.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import shlex
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from tools.environments.modal_common import (
|
||||
BaseModalExecutionEnvironment,
|
||||
ModalExecStart,
|
||||
PreparedModalExec,
|
||||
from tools.environments.base import (
|
||||
BaseEnvironment,
|
||||
_ThreadedProcessHandle,
|
||||
_file_mtime_key,
|
||||
_load_json_store,
|
||||
_save_json_store,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -26,20 +26,12 @@ _SNAPSHOT_STORE = get_hermes_home() / "modal_snapshots.json"
|
|||
_DIRECT_SNAPSHOT_NAMESPACE = "direct"
|
||||
|
||||
|
||||
def _load_snapshots() -> Dict[str, str]:
|
||||
"""Load snapshot ID mapping from disk."""
|
||||
if _SNAPSHOT_STORE.exists():
|
||||
try:
|
||||
return json.loads(_SNAPSHOT_STORE.read_text())
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
def _load_snapshots() -> dict:
|
||||
return _load_json_store(_SNAPSHOT_STORE)
|
||||
|
||||
|
||||
def _save_snapshots(data: Dict[str, str]) -> None:
|
||||
"""Persist snapshot ID mapping to disk."""
|
||||
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
|
||||
def _save_snapshots(data: dict) -> None:
|
||||
_save_json_store(_SNAPSHOT_STORE, data)
|
||||
|
||||
|
||||
def _direct_snapshot_key(task_id: str) -> str:
|
||||
|
|
@ -47,23 +39,18 @@ def _direct_snapshot_key(task_id: str) -> str:
|
|||
|
||||
|
||||
def _get_snapshot_restore_candidate(task_id: str) -> tuple[str | None, bool]:
    """Look up the snapshot id to restore for *task_id*.

    The namespaced key (``_direct_snapshot_key(task_id)``) is tried first,
    then the bare ``task_id`` as the legacy key. Only non-empty string
    values are accepted.

    Returns:
        ``(snapshot_id, from_legacy_key)``, or ``(None, False)`` when
        neither key holds a usable id.
    """
    store = _load_snapshots()
    lookups = (
        (_direct_snapshot_key(task_id), False),
        (task_id, True),
    )
    for key, from_legacy_key in lookups:
        found = store.get(key)
        if isinstance(found, str) and found:
            return found, from_legacy_key
    return None, False
|
||||
|
||||
|
||||
def _store_direct_snapshot(task_id: str, snapshot_id: str) -> None:
|
||||
"""Persist the direct Modal snapshot id under the direct namespace."""
|
||||
snapshots = _load_snapshots()
|
||||
snapshots[_direct_snapshot_key(task_id)] = snapshot_id
|
||||
snapshots.pop(task_id, None)
|
||||
|
|
@ -71,10 +58,8 @@ def _store_direct_snapshot(task_id: str, snapshot_id: str) -> None:
|
|||
|
||||
|
||||
def _delete_direct_snapshot(task_id: str, snapshot_id: str | None = None) -> None:
|
||||
"""Remove direct Modal snapshot entries for a task, including legacy keys."""
|
||||
snapshots = _load_snapshots()
|
||||
updated = False
|
||||
|
||||
for key in (_direct_snapshot_key(task_id), task_id):
|
||||
value = snapshots.get(key)
|
||||
if value is None:
|
||||
|
|
@ -82,13 +67,15 @@ def _delete_direct_snapshot(task_id: str, snapshot_id: str | None = None) -> Non
|
|||
if snapshot_id is None or value == snapshot_id:
|
||||
snapshots.pop(key, None)
|
||||
updated = True
|
||||
|
||||
if updated:
|
||||
_save_snapshots(snapshots)
|
||||
|
||||
|
||||
def _resolve_modal_image(image_spec: Any) -> Any:
|
||||
"""Convert registry references or snapshot ids into Modal image objects."""
|
||||
"""Convert registry references or snapshot ids into Modal image objects.
|
||||
|
||||
Includes add_python support for ubuntu/debian images (absorbed from PR 4511).
|
||||
"""
|
||||
import modal as _modal
|
||||
|
||||
if not isinstance(image_spec, str):
|
||||
|
|
@ -97,12 +84,22 @@ def _resolve_modal_image(image_spec: Any) -> Any:
|
|||
if image_spec.startswith("im-"):
|
||||
return _modal.Image.from_id(image_spec)
|
||||
|
||||
# PR 4511: add python to ubuntu/debian images that don't have it
|
||||
lower = image_spec.lower()
|
||||
add_python = any(base in lower for base in ("ubuntu", "debian"))
|
||||
|
||||
setup_commands = [
|
||||
"RUN rm -rf /usr/local/lib/python*/site-packages/pip* 2>/dev/null; "
|
||||
"python -m ensurepip --upgrade --default-pip 2>/dev/null || true",
|
||||
]
|
||||
if add_python:
|
||||
setup_commands.insert(0,
|
||||
"RUN apt-get update -qq && apt-get install -y -qq python3 python3-venv > /dev/null 2>&1 || true"
|
||||
)
|
||||
|
||||
return _modal.Image.from_registry(
|
||||
image_spec,
|
||||
setup_dockerfile_commands=[
|
||||
"RUN rm -rf /usr/local/lib/python*/site-packages/pip* 2>/dev/null; "
|
||||
"python -m ensurepip --upgrade --default-pip 2>/dev/null || true",
|
||||
],
|
||||
setup_dockerfile_commands=setup_commands,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -138,19 +135,15 @@ class _AsyncWorker:
|
|||
self._thread.join(timeout=10)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DirectModalExecHandle:
|
||||
thread: threading.Thread
|
||||
result_holder: Dict[str, Any]
|
||||
class ModalEnvironment(BaseEnvironment):
|
||||
"""Modal cloud execution via native Modal sandboxes.
|
||||
|
||||
|
||||
class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
"""Modal cloud execution via native Modal sandboxes."""
|
||||
Spawn-per-call via _ThreadedProcessHandle wrapping async SDK calls.
|
||||
cancel_fn wired to sandbox.terminate for interrupt support.
|
||||
"""
|
||||
|
||||
_stdin_mode = "heredoc"
|
||||
_poll_interval_seconds = 0.2
|
||||
_interrupt_output = "[Command interrupted - Modal sandbox terminated]"
|
||||
_unexpected_error_prefix = "Modal execution error"
|
||||
_snapshot_timeout = 60 # Modal cold starts can be slow
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -170,6 +163,7 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
|||
self._app = None
|
||||
self._worker = _AsyncWorker()
|
||||
self._synced_files: Dict[str, tuple] = {}
|
||||
self._last_sync_time: float = 0
|
||||
|
||||
sandbox_kwargs = dict(modal_sandbox_kwargs or {})
|
||||
|
||||
|
|
@ -199,27 +193,13 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
|||
remote_path=mount_entry["container_path"],
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"Modal: mounting credential %s -> %s",
|
||||
mount_entry["host_path"],
|
||||
mount_entry["container_path"],
|
||||
)
|
||||
|
||||
# Mount individual skill files (symlinks filtered out).
|
||||
skills_files = iter_skills_files()
|
||||
for entry in skills_files:
|
||||
for entry in iter_skills_files():
|
||||
cred_mounts.append(
|
||||
_modal.Mount.from_local_file(
|
||||
entry["host_path"],
|
||||
remote_path=entry["container_path"],
|
||||
)
|
||||
)
|
||||
if skills_files:
|
||||
logger.info("Modal: mounting %d skill files", len(skills_files))
|
||||
|
||||
# Mount host-side cache files (documents, images, audio,
|
||||
# screenshots). New files arriving mid-session are picked up
|
||||
# by _sync_files() before each command execution.
|
||||
cache_files = iter_cache_files()
|
||||
for entry in cache_files:
|
||||
cred_mounts.append(
|
||||
|
|
@ -228,8 +208,6 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
|||
remote_path=entry["container_path"],
|
||||
)
|
||||
)
|
||||
if cache_files:
|
||||
logger.info("Modal: mounting %d cache files", len(cache_files))
|
||||
except Exception as e:
|
||||
logger.debug("Modal: could not load credential file mounts: %s", e)
|
||||
|
||||
|
|
@ -243,8 +221,7 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
|||
existing_mounts.extend(cred_mounts)
|
||||
create_kwargs["mounts"] = existing_mounts
|
||||
sandbox = await _modal.Sandbox.create.aio(
|
||||
"sleep",
|
||||
"infinity",
|
||||
"sleep", "infinity",
|
||||
image=image_spec,
|
||||
app=app,
|
||||
timeout=int(create_kwargs.pop("timeout", 3600)),
|
||||
|
|
@ -255,57 +232,41 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
|||
try:
|
||||
target_image_spec = restored_snapshot_id or image
|
||||
try:
|
||||
# _resolve_modal_image keeps the Modal bootstrap fix together:
|
||||
# it applies setup_dockerfile_commands with ensurepip before
|
||||
# Modal builds registry images, while snapshot ids restore via
|
||||
# modal.Image.from_id() without rebuilding.
|
||||
effective_image = _resolve_modal_image(target_image_spec)
|
||||
self._app, self._sandbox = self._worker.run_coroutine(
|
||||
_create_sandbox(effective_image),
|
||||
timeout=300,
|
||||
_create_sandbox(effective_image), timeout=300,
|
||||
)
|
||||
except Exception as exc:
|
||||
if not restored_snapshot_id:
|
||||
raise
|
||||
|
||||
logger.warning(
|
||||
"Modal: failed to restore snapshot %s, retrying with base image: %s",
|
||||
restored_snapshot_id[:20],
|
||||
exc,
|
||||
restored_snapshot_id[:20], exc,
|
||||
)
|
||||
_delete_direct_snapshot(self._task_id, restored_snapshot_id)
|
||||
base_image = _resolve_modal_image(image)
|
||||
self._app, self._sandbox = self._worker.run_coroutine(
|
||||
_create_sandbox(base_image),
|
||||
timeout=300,
|
||||
_create_sandbox(base_image), timeout=300,
|
||||
)
|
||||
else:
|
||||
if restored_snapshot_id and restored_from_legacy_key:
|
||||
_store_direct_snapshot(self._task_id, restored_snapshot_id)
|
||||
logger.info(
|
||||
"Modal: migrated legacy snapshot entry for task %s",
|
||||
self._task_id,
|
||||
)
|
||||
except Exception:
|
||||
self._worker.stop()
|
||||
raise
|
||||
|
||||
logger.info("Modal: sandbox created (task=%s)", self._task_id)
|
||||
self.init_session()
|
||||
|
||||
def _push_file_to_sandbox(self, host_path: str, container_path: str) -> bool:
|
||||
"""Push a single file into the sandbox if changed. Returns True if synced."""
|
||||
hp = Path(host_path)
|
||||
try:
|
||||
stat = hp.stat()
|
||||
file_key = (stat.st_mtime, stat.st_size)
|
||||
except OSError:
|
||||
"""Push a single file into the sandbox if changed."""
|
||||
file_key = _file_mtime_key(host_path)
|
||||
if file_key is None:
|
||||
return False
|
||||
|
||||
if self._synced_files.get(container_path) == file_key:
|
||||
return False
|
||||
|
||||
try:
|
||||
content = hp.read_bytes()
|
||||
content = Path(host_path).read_bytes()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
|
@ -326,85 +287,55 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
|||
return True
|
||||
|
||||
def _sync_files(self) -> None:
|
||||
"""Push credential, skill, and cache files into the running sandbox.
|
||||
|
||||
Runs before each command. Uses mtime+size caching so only changed
|
||||
files are pushed (~13μs overhead in the no-op case). Cache files
|
||||
are especially important here — new uploads/screenshots may appear
|
||||
mid-session after sandbox creation.
|
||||
"""
|
||||
"""Push credential, skill, and cache files into the running sandbox."""
|
||||
try:
|
||||
from tools.credential_files import (
|
||||
get_credential_file_mounts,
|
||||
iter_skills_files,
|
||||
iter_cache_files,
|
||||
)
|
||||
|
||||
for entry in get_credential_file_mounts():
|
||||
if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]):
|
||||
logger.debug("Modal: synced credential %s", entry["container_path"])
|
||||
|
||||
self._push_file_to_sandbox(entry["host_path"], entry["container_path"])
|
||||
for entry in iter_skills_files():
|
||||
if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]):
|
||||
logger.debug("Modal: synced skill file %s", entry["container_path"])
|
||||
|
||||
self._push_file_to_sandbox(entry["host_path"], entry["container_path"])
|
||||
for entry in iter_cache_files():
|
||||
if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]):
|
||||
logger.debug("Modal: synced cache file %s", entry["container_path"])
|
||||
self._push_file_to_sandbox(entry["host_path"], entry["container_path"])
|
||||
except Exception as e:
|
||||
logger.debug("Modal: file sync failed: %s", e)
|
||||
|
||||
def _before_execute(self) -> None:
|
||||
self._sync_files()
|
||||
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None):
|
||||
"""Return a _ThreadedProcessHandle wrapping an async Modal sandbox exec."""
|
||||
sandbox = self._sandbox
|
||||
worker = self._worker
|
||||
|
||||
def _start_modal_exec(self, prepared: PreparedModalExec) -> ModalExecStart:
|
||||
full_command = f"cd {shlex.quote(prepared.cwd)} && {prepared.command}"
|
||||
result_holder = {"value": None, "error": None}
|
||||
def cancel():
|
||||
worker.run_coroutine(sandbox.terminate.aio(), timeout=15)
|
||||
|
||||
def _run():
|
||||
try:
|
||||
async def _do_execute():
|
||||
process = await self._sandbox.exec.aio(
|
||||
"bash",
|
||||
"-c",
|
||||
full_command,
|
||||
timeout=prepared.timeout,
|
||||
)
|
||||
stdout = await process.stdout.read.aio()
|
||||
stderr = await process.stderr.read.aio()
|
||||
exit_code = await process.wait.aio()
|
||||
if isinstance(stdout, bytes):
|
||||
stdout = stdout.decode("utf-8", errors="replace")
|
||||
if isinstance(stderr, bytes):
|
||||
stderr = stderr.decode("utf-8", errors="replace")
|
||||
output = stdout
|
||||
if stderr:
|
||||
output = f"{stdout}\n{stderr}" if stdout else stderr
|
||||
return self._result(output, exit_code)
|
||||
def exec_fn() -> tuple[str, int]:
|
||||
async def _do():
|
||||
args = ["bash"]
|
||||
if login:
|
||||
args.extend(["-l", "-c", cmd_string])
|
||||
else:
|
||||
args.extend(["-c", cmd_string])
|
||||
process = await sandbox.exec.aio(*args, timeout=timeout)
|
||||
stdout = await process.stdout.read.aio()
|
||||
stderr = await process.stderr.read.aio()
|
||||
exit_code = await process.wait.aio()
|
||||
if isinstance(stdout, bytes):
|
||||
stdout = stdout.decode("utf-8", errors="replace")
|
||||
if isinstance(stderr, bytes):
|
||||
stderr = stderr.decode("utf-8", errors="replace")
|
||||
output = stdout
|
||||
if stderr:
|
||||
output = f"{stdout}\n{stderr}" if stdout else stderr
|
||||
return output, exit_code
|
||||
|
||||
result_holder["value"] = self._worker.run_coroutine(
|
||||
_do_execute(),
|
||||
timeout=prepared.timeout + 30,
|
||||
)
|
||||
except Exception as e:
|
||||
result_holder["error"] = e
|
||||
return worker.run_coroutine(_do(), timeout=timeout + 30)
|
||||
|
||||
t = threading.Thread(target=_run, daemon=True)
|
||||
t.start()
|
||||
return ModalExecStart(handle=_DirectModalExecHandle(thread=t, result_holder=result_holder))
|
||||
|
||||
def _poll_modal_exec(self, handle: _DirectModalExecHandle) -> dict | None:
|
||||
if handle.thread.is_alive():
|
||||
return None
|
||||
if handle.result_holder["error"]:
|
||||
return self._error_result(f"Modal execution error: {handle.result_holder['error']}")
|
||||
return handle.result_holder["value"]
|
||||
|
||||
def _cancel_modal_exec(self, handle: _DirectModalExecHandle) -> None:
|
||||
self._worker.run_coroutine(
|
||||
self._sandbox.terminate.aio(),
|
||||
timeout=15,
|
||||
)
|
||||
return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
||||
|
||||
def cleanup(self):
|
||||
"""Snapshot the filesystem (if persistent) then stop the sandbox."""
|
||||
|
|
@ -426,17 +357,13 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
|||
_store_direct_snapshot(self._task_id, snapshot_id)
|
||||
logger.info(
|
||||
"Modal: saved filesystem snapshot %s for task %s",
|
||||
snapshot_id[:20],
|
||||
self._task_id,
|
||||
snapshot_id[:20], self._task_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Modal: filesystem snapshot failed: %s", e)
|
||||
|
||||
try:
|
||||
self._worker.run_coroutine(
|
||||
self._sandbox.terminate.aio(),
|
||||
timeout=15,
|
||||
)
|
||||
self._worker.run_coroutine(self._sandbox.terminate.aio(), timeout=15)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -56,7 +56,15 @@ def wrap_modal_sudo_pipe(command: str, sudo_stdin: str) -> str:
|
|||
|
||||
|
||||
class BaseModalExecutionEnvironment(BaseEnvironment):
|
||||
"""Common execute() flow for direct and managed Modal transports."""
|
||||
"""Execution flow for the *managed* Modal transport (gateway-owned sandbox).
|
||||
|
||||
This deliberately overrides :meth:`BaseEnvironment.execute` because the
|
||||
tool-gateway handles command preparation, CWD tracking, and env-snapshot
|
||||
management on the server side. The base class's ``_wrap_command`` /
|
||||
``_wait_for_process`` / snapshot machinery does not apply here — the
|
||||
gateway owns that responsibility. See ``ManagedModalEnvironment`` for the
|
||||
concrete subclass.
|
||||
"""
|
||||
|
||||
_stdin_mode = "payload"
|
||||
_poll_interval_seconds = 0.25
|
||||
|
|
@ -124,7 +132,7 @@ class BaseModalExecutionEnvironment(BaseEnvironment):
|
|||
|
||||
def _before_execute(self) -> None:
|
||||
"""Hook for backends that need pre-exec sync or validation."""
|
||||
return None
|
||||
pass
|
||||
|
||||
def _prepare_modal_exec(
|
||||
self,
|
||||
|
|
@ -1,290 +0,0 @@
|
|||
"""Persistent shell mixin: file-based IPC protocol for long-lived bash shells."""
|
||||
|
||||
import logging
|
||||
import shlex
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PersistentShellMixin:
|
||||
"""Mixin that adds persistent shell capability to any BaseEnvironment.
|
||||
|
||||
Subclasses must implement ``_spawn_shell_process()``, ``_read_temp_files()``,
|
||||
``_kill_shell_children()``, ``_execute_oneshot()``, and ``_cleanup_temp_files()``.
|
||||
"""
|
||||
|
||||
persistent: bool
|
||||
|
||||
@abstractmethod
|
||||
def _spawn_shell_process(self) -> subprocess.Popen: ...
|
||||
|
||||
@abstractmethod
|
||||
def _read_temp_files(self, *paths: str) -> list[str]: ...
|
||||
|
||||
@abstractmethod
|
||||
def _kill_shell_children(self): ...
|
||||
|
||||
@abstractmethod
|
||||
def _execute_oneshot(self, command: str, cwd: str, *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict: ...
|
||||
|
||||
@abstractmethod
|
||||
def _cleanup_temp_files(self): ...
|
||||
|
||||
_session_id: str = ""
|
||||
_poll_interval_start: float = 0.01 # initial poll interval (10ms)
|
||||
_poll_interval_max: float = 0.25 # max poll interval (250ms) — reduces I/O for long commands
|
||||
|
||||
@property
|
||||
def _temp_prefix(self) -> str:
|
||||
return f"/tmp/hermes-persistent-{self._session_id}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_persistent_shell(self):
    """Spawn the long-lived shell and prepare its file-based IPC paths.

    Steps, in order:
      1. reset shell bookkeeping and pick a fresh 12-hex-char session id;
      2. derive the stdout/stderr/status/cwd/pid file paths from
         ``self._temp_prefix``;
      3. spawn the shell and start a daemon thread that drains its stdout;
      4. send an init script that creates those files and records the
         shell's PID and working directory;
      5. poll the PID file for up to 3 seconds, then adopt the reported
         working directory as ``self.cwd`` if one was written.
    """
    self._shell_lock = threading.Lock()
    self._shell_proc: subprocess.Popen | None = None
    self._shell_alive: bool = False
    self._shell_pid: int | None = None

    self._session_id = uuid.uuid4().hex[:12]
    prefix = self._temp_prefix
    self._pshell_stdout = f"{prefix}-stdout"
    self._pshell_stderr = f"{prefix}-stderr"
    self._pshell_status = f"{prefix}-status"
    self._pshell_cwd = f"{prefix}-cwd"
    self._pshell_pid_file = f"{prefix}-pid"

    self._shell_proc = self._spawn_shell_process()
    self._shell_alive = True

    drainer = threading.Thread(target=self._drain_shell_output, daemon=True)
    self._drain_thread = drainer
    drainer.start()

    ipc_files = " ".join((
        self._pshell_stdout,
        self._pshell_stderr,
        self._pshell_status,
        self._pshell_cwd,
        self._pshell_pid_file,
    ))
    init_script = "".join((
        f"export TERM=${{TERM:-dumb}}\n",
        f"touch {ipc_files}\n",
        f"echo $$ > {self._pshell_pid_file}\n",
        f"pwd > {self._pshell_cwd}\n",
    ))
    self._send_to_shell(init_script)

    # Wait for the shell to report its own PID through the pid file.
    give_up_at = time.monotonic() + 3.0
    while time.monotonic() < give_up_at:
        pid_text = self._read_temp_files(self._pshell_pid_file)[0].strip()
        if pid_text.isdigit():
            self._shell_pid = int(pid_text)
            break
        time.sleep(0.05)
    else:
        logger.warning("Could not read persistent shell PID")
        self._shell_pid = None

    if self._shell_pid:
        logger.info(
            "Persistent shell started (session=%s, pid=%d)",
            self._session_id, self._shell_pid,
        )

    # NOTE(review): the cwd file is written after the pid file, so this read
    # may in principle see it still empty — confirm whether that matters.
    shell_cwd = self._read_temp_files(self._pshell_cwd)[0].strip()
    if shell_cwd:
        self.cwd = shell_cwd
|
||||
|
||||
def _cleanup_persistent_shell(self):
|
||||
if self._shell_proc is None:
|
||||
return
|
||||
|
||||
if self._session_id:
|
||||
self._cleanup_temp_files()
|
||||
|
||||
try:
|
||||
self._shell_proc.stdin.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._shell_proc.terminate()
|
||||
self._shell_proc.wait(timeout=3)
|
||||
except subprocess.TimeoutExpired:
|
||||
self._shell_proc.kill()
|
||||
|
||||
self._shell_alive = False
|
||||
self._shell_proc = None
|
||||
|
||||
if hasattr(self, "_drain_thread") and self._drain_thread.is_alive():
|
||||
self._drain_thread.join(timeout=1.0)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# execute() / cleanup() — shared dispatcher, subclasses inherit
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
if self.persistent:
|
||||
return self._execute_persistent(
|
||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
||||
)
|
||||
return self._execute_oneshot(
|
||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
||||
)
|
||||
|
||||
def execute_oneshot(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
"""Always use the oneshot (non-persistent) execution path.
|
||||
|
||||
This bypasses _shell_lock so it can run concurrently with a
|
||||
long-running command in the persistent shell — used by
|
||||
execute_code's file-based RPC polling thread.
|
||||
"""
|
||||
return self._execute_oneshot(
|
||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
||||
)
|
||||
|
||||
def cleanup(self):
    """Shut down the persistent shell when this environment uses one."""
    if not self.persistent:
        return
    self._cleanup_persistent_shell()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Shell I/O
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _drain_shell_output(self):
|
||||
try:
|
||||
for _ in self._shell_proc.stdout:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
self._shell_alive = False
|
||||
|
||||
def _send_to_shell(self, text: str):
|
||||
if not self._shell_alive or self._shell_proc is None:
|
||||
return
|
||||
try:
|
||||
self._shell_proc.stdin.write(text)
|
||||
self._shell_proc.stdin.flush()
|
||||
except (BrokenPipeError, OSError):
|
||||
self._shell_alive = False
|
||||
|
||||
def _read_persistent_output(self) -> tuple[str, int, str]:
|
||||
stdout, stderr, status_raw, cwd = self._read_temp_files(
|
||||
self._pshell_stdout, self._pshell_stderr,
|
||||
self._pshell_status, self._pshell_cwd,
|
||||
)
|
||||
output = self._merge_output(stdout, stderr)
|
||||
status = status_raw.strip()
|
||||
if ":" in status:
|
||||
status = status.split(":", 1)[1]
|
||||
try:
|
||||
exit_code = int(status.strip())
|
||||
except ValueError:
|
||||
exit_code = 1
|
||||
return output, exit_code, cwd.strip()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _execute_persistent(self, command: str, cwd: str, *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
if not self._shell_alive:
|
||||
logger.info("Persistent shell died, restarting...")
|
||||
self._init_persistent_shell()
|
||||
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
effective_timeout = timeout or self.timeout
|
||||
if stdin_data or sudo_stdin:
|
||||
return self._execute_oneshot(
|
||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
||||
)
|
||||
|
||||
with self._shell_lock:
|
||||
return self._execute_persistent_locked(
|
||||
exec_command, cwd, effective_timeout,
|
||||
)
|
||||
|
||||
def _execute_persistent_locked(self, command: str, cwd: str,
                               timeout: int) -> dict:
    """Run *command* in the persistent shell and poll for its result.

    The status, stdout and stderr temp files are truncated first. A
    script is then sent to the shell that changes into the working
    directory (``cwd`` or ``self.cwd``), evaluates the command with stdin
    from ``/dev/null`` and stdout/stderr redirected to the temp files,
    records ``pwd``, and finally writes ``<cmd_id>:<exit_code>`` to the
    status file. The status file is polled until it starts with this
    call's id.

    Returns:
        A dict with ``"output"`` and ``"returncode"``. Interruption yields
        return code 130; a timeout yields 124 with the partial output, or
        ``self._timeout_result(timeout)`` when there is no output; a dead
        shell yields return code 1.
    """
    target_dir = cwd or self.cwd
    token = uuid.uuid4().hex[:8]

    # Reset the IPC files so stale results are never mistaken for new ones.
    self._send_to_shell(
        f": > {self._pshell_stdout}\n"
        f": > {self._pshell_stderr}\n"
        f": > {self._pshell_status}\n"
    )

    # Close the single-quoted string, emit an escaped quote, reopen it.
    quoted = command.replace("'", "'\\''")
    script_lines = [
        f"cd {shlex.quote(target_dir)}",
        f"eval '{quoted}' < /dev/null > {self._pshell_stdout} 2> {self._pshell_stderr}",
        "__EC=$?",
        f"pwd > {self._pshell_cwd}",
        f"echo {token}:$__EC > {self._pshell_status}",
    ]
    self._send_to_shell("\n".join(script_lines) + "\n")

    deadline = time.monotonic() + timeout
    # Per the original comment: starts at 10ms, backs off to 250ms
    # (the actual values come from attributes not visible here).
    delay = self._poll_interval_start
    marker = token + ":"

    while True:
        if is_interrupted():
            self._kill_shell_children()
            partial, _, _ = self._read_persistent_output()
            return {
                "output": partial + "\n[Command interrupted]",
                "returncode": 130,
            }

        if time.monotonic() > deadline:
            self._kill_shell_children()
            partial, _, _ = self._read_persistent_output()
            if not partial:
                return self._timeout_result(timeout)
            return {
                "output": partial + f"\n[Command timed out after {timeout}s]",
                "returncode": 124,
            }

        if not self._shell_alive:
            return {
                "output": "Persistent shell died during execution",
                "returncode": 1,
            }

        status_text = self._read_temp_files(self._pshell_status)[0].strip()
        if status_text.startswith(marker):
            break

        time.sleep(delay)
        # Exponential backoff: quick commands are noticed fast, while
        # long-running ones are polled less often — the original note says
        # this cuts I/O by 10-25x on WSL2, where polling keeps the VM hot.
        delay = min(delay * 1.5, self._poll_interval_max)

    output, exit_code, new_cwd = self._read_persistent_output()
    if new_cwd:
        self.cwd = new_cwd
    return {"output": output, "returncode": exit_code}
|
||||
|
||||
@staticmethod
|
||||
def _merge_output(stdout: str, stderr: str) -> str:
|
||||
parts = []
|
||||
if stdout.strip():
|
||||
parts.append(stdout.rstrip("\n"))
|
||||
if stderr.strip():
|
||||
parts.append(stderr.rstrip("\n"))
|
||||
return "\n".join(parts)
|
||||
|
|
@ -5,20 +5,22 @@ Supports configurable resource limits and optional filesystem persistence
|
|||
via writable overlay directories that survive across sessions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
from tools.environments.base import (
|
||||
BaseEnvironment,
|
||||
_load_json_store,
|
||||
_popen_bash,
|
||||
_save_json_store,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -26,11 +28,7 @@ _SNAPSHOT_STORE = get_hermes_home() / "singularity_snapshots.json"
|
|||
|
||||
|
||||
def _find_singularity_executable() -> str:
|
||||
"""Locate the apptainer or singularity CLI binary.
|
||||
|
||||
Returns the executable name (``"apptainer"`` or ``"singularity"``).
|
||||
Raises ``RuntimeError`` with install instructions if neither is found.
|
||||
"""
|
||||
"""Locate the apptainer or singularity CLI binary."""
|
||||
if shutil.which("apptainer"):
|
||||
return "apptainer"
|
||||
if shutil.which("singularity"):
|
||||
|
|
@ -43,66 +41,34 @@ def _find_singularity_executable() -> str:
|
|||
|
||||
|
||||
def _ensure_singularity_available() -> str:
|
||||
"""Preflight check: resolve the executable and verify it responds.
|
||||
|
||||
Returns the executable name on success.
|
||||
Raises ``RuntimeError`` with an actionable message on failure.
|
||||
"""
|
||||
"""Preflight check: resolve the executable and verify it responds."""
|
||||
exe = _find_singularity_executable()
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[exe, "version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
[exe, "version"], capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(
|
||||
f"Singularity backend selected but the resolved executable '{exe}' "
|
||||
"could not be executed. Check your installation."
|
||||
f"Singularity backend selected but '{exe}' could not be executed."
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(
|
||||
f"'{exe} version' timed out. The runtime may be misconfigured."
|
||||
)
|
||||
raise RuntimeError(f"'{exe} version' timed out.")
|
||||
|
||||
if result.returncode != 0:
|
||||
stderr = result.stderr.strip()[:200]
|
||||
raise RuntimeError(
|
||||
f"'{exe} version' failed (exit code {result.returncode}): {stderr}"
|
||||
)
|
||||
|
||||
raise RuntimeError(f"'{exe} version' failed (exit code {result.returncode}): {stderr}")
|
||||
return exe
|
||||
|
||||
|
||||
def _load_snapshots() -> Dict[str, str]:
|
||||
if _SNAPSHOT_STORE.exists():
|
||||
try:
|
||||
return json.loads(_SNAPSHOT_STORE.read_text())
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
def _load_snapshots() -> dict:
|
||||
return _load_json_store(_SNAPSHOT_STORE)
|
||||
|
||||
|
||||
def _save_snapshots(data: Dict[str, str]) -> None:
|
||||
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
|
||||
def _save_snapshots(data: dict) -> None:
|
||||
_save_json_store(_SNAPSHOT_STORE, data)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Singularity helpers (scratch dir, SIF cache, SIF building)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _get_scratch_dir() -> Path:
|
||||
"""Get the best directory for Singularity sandboxes.
|
||||
|
||||
Resolution order:
|
||||
1. TERMINAL_SCRATCH_DIR (explicit override)
|
||||
2. TERMINAL_SANDBOX_DIR / singularity (shared sandbox root)
|
||||
3. /scratch (common on HPC clusters)
|
||||
4. ~/.hermes/sandboxes/singularity (fallback)
|
||||
"""
|
||||
custom_scratch = os.getenv("TERMINAL_SCRATCH_DIR")
|
||||
if custom_scratch:
|
||||
scratch_path = Path(custom_scratch)
|
||||
|
|
@ -124,7 +90,6 @@ def _get_scratch_dir() -> Path:
|
|||
|
||||
|
||||
def _get_apptainer_cache_dir() -> Path:
|
||||
"""Get the Apptainer cache directory for SIF images."""
|
||||
cache_dir = os.getenv("APPTAINER_CACHEDIR")
|
||||
if cache_dir:
|
||||
cache_path = Path(cache_dir)
|
||||
|
|
@ -140,11 +105,6 @@ _sif_build_lock = threading.Lock()
|
|||
|
||||
|
||||
def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
|
||||
"""Get or build a SIF image from a docker:// URL.
|
||||
|
||||
Returns the path unchanged if it's already a .sif file.
|
||||
For docker:// URLs, checks the cache and builds if needed.
|
||||
"""
|
||||
if image.endswith('.sif') and Path(image).exists():
|
||||
return image
|
||||
if not image.startswith('docker://'):
|
||||
|
|
@ -193,19 +153,12 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
|
|||
return image
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# SingularityEnvironment
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
class SingularityEnvironment(BaseEnvironment):
|
||||
"""Hardened Singularity/Apptainer container with resource limits and persistence.
|
||||
|
||||
Security: --containall (isolated PID/IPC/mount namespaces, no host home mount),
|
||||
--no-home, writable-tmpfs for scratch space. The container cannot see or modify
|
||||
the host filesystem outside of explicitly bound paths.
|
||||
|
||||
Persistence: when enabled, the writable overlay directory is preserved across
|
||||
sessions so installed packages and files survive cleanup/restore.
|
||||
Spawn-per-call: every execute() spawns a fresh ``apptainer exec ... bash -c`` process.
|
||||
Session snapshot preserves env vars across calls.
|
||||
CWD persists via in-band stdout markers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -227,12 +180,9 @@ class SingularityEnvironment(BaseEnvironment):
|
|||
self._persistent = persistent_filesystem
|
||||
self._task_id = task_id
|
||||
self._overlay_dir: Optional[Path] = None
|
||||
|
||||
# Resource limits
|
||||
self._cpu = cpu
|
||||
self._memory = memory
|
||||
|
||||
# Persistent overlay directory
|
||||
if self._persistent:
|
||||
overlay_base = _get_scratch_dir() / "hermes-overlays"
|
||||
overlay_base.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -240,42 +190,26 @@ class SingularityEnvironment(BaseEnvironment):
|
|||
self._overlay_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._start_instance()
|
||||
self.init_session()
|
||||
|
||||
def _start_instance(self):
|
||||
cmd = [self.executable, "instance", "start"]
|
||||
|
||||
# Security: full isolation from host
|
||||
cmd.extend(["--containall", "--no-home"])
|
||||
|
||||
# Writable layer
|
||||
if self._persistent and self._overlay_dir:
|
||||
# Persistent writable overlay -- survives across restarts
|
||||
cmd.extend(["--overlay", str(self._overlay_dir)])
|
||||
else:
|
||||
cmd.append("--writable-tmpfs")
|
||||
|
||||
# Mount credential files and skills directory (read-only).
|
||||
try:
|
||||
from tools.credential_files import get_credential_file_mounts, get_skills_directory_mount
|
||||
|
||||
for mount_entry in get_credential_file_mounts():
|
||||
cmd.extend(["--bind", f"{mount_entry['host_path']}:{mount_entry['container_path']}:ro"])
|
||||
logger.info(
|
||||
"Singularity: binding credential %s -> %s",
|
||||
mount_entry["host_path"],
|
||||
mount_entry["container_path"],
|
||||
)
|
||||
for skills_mount in get_skills_directory_mount():
|
||||
cmd.extend(["--bind", f"{skills_mount['host_path']}:{skills_mount['container_path']}:ro"])
|
||||
logger.info(
|
||||
"Singularity: binding skills dir %s -> %s",
|
||||
skills_mount["host_path"],
|
||||
skills_mount["container_path"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Singularity: could not load credential/skills mounts: %s", e)
|
||||
|
||||
# Resource limits (cgroup-based, may require root or appropriate config)
|
||||
if self._memory > 0:
|
||||
cmd.extend(["--memory", f"{self._memory}M"])
|
||||
if self._cpu > 0:
|
||||
|
|
@ -288,94 +222,29 @@ class SingularityEnvironment(BaseEnvironment):
|
|||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Failed to start instance: {result.stderr}")
|
||||
self._instance_started = True
|
||||
logger.info("Singularity instance %s started (persistent=%s)",
|
||||
logger.info("Singularity instance %s started (persistent=%s)",
|
||||
self.instance_id, self._persistent)
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError("Instance start timed out")
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None) -> subprocess.Popen:
|
||||
"""Spawn a bash process inside the Singularity instance."""
|
||||
if not self._instance_started:
|
||||
return {"output": "Instance not started", "returncode": -1}
|
||||
raise RuntimeError("Singularity instance not started")
|
||||
|
||||
effective_timeout = timeout or self.timeout
|
||||
work_dir = cwd or self.cwd
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
|
||||
# Merge sudo password (if any) with caller-supplied stdin_data.
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
effective_stdin = sudo_stdin
|
||||
cmd = [self.executable, "exec",
|
||||
f"instance://{self.instance_id}"]
|
||||
if login:
|
||||
cmd.extend(["bash", "-l", "-c", cmd_string])
|
||||
else:
|
||||
effective_stdin = stdin_data
|
||||
cmd.extend(["bash", "-c", cmd_string])
|
||||
|
||||
# apptainer exec --pwd doesn't expand ~, so prepend a cd into the command.
|
||||
# Keep ~ unquoted (for shell expansion) and quote only the subpath.
|
||||
if work_dir == "~":
|
||||
exec_command = f"cd ~ && {exec_command}"
|
||||
work_dir = "/tmp"
|
||||
elif work_dir.startswith("~/"):
|
||||
exec_command = f"cd ~/{shlex.quote(work_dir[2:])} && {exec_command}"
|
||||
work_dir = "/tmp"
|
||||
|
||||
cmd = [self.executable, "exec", "--pwd", work_dir,
|
||||
f"instance://{self.instance_id}",
|
||||
"bash", "-c", exec_command]
|
||||
|
||||
try:
|
||||
import time as _time
|
||||
_output_chunks = []
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
if effective_stdin:
|
||||
try:
|
||||
proc.stdin.write(effective_stdin)
|
||||
proc.stdin.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _drain():
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
_output_chunks.append(line)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
reader = threading.Thread(target=_drain, daemon=True)
|
||||
reader.start()
|
||||
deadline = _time.monotonic() + effective_timeout
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=1)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(_output_chunks) + "\n[Command interrupted]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if _time.monotonic() > deadline:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return self._timeout_result(effective_timeout)
|
||||
_time.sleep(0.2)
|
||||
|
||||
reader.join(timeout=5)
|
||||
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
|
||||
except Exception as e:
|
||||
return {"output": f"Singularity execution error: {e}", "returncode": 1}
|
||||
return _popen_bash(cmd, stdin_data)
|
||||
|
||||
def cleanup(self):
|
||||
"""Stop the instance. If persistent, the overlay dir survives for next creation."""
|
||||
"""Stop the instance. If persistent, the overlay dir survives."""
|
||||
if self._instance_started:
|
||||
try:
|
||||
subprocess.run(
|
||||
|
|
@ -387,7 +256,6 @@ class SingularityEnvironment(BaseEnvironment):
|
|||
logger.warning("Failed to stop Singularity instance %s: %s", self.instance_id, e)
|
||||
self._instance_started = False
|
||||
|
||||
# Record overlay path for persistence restoration
|
||||
if self._persistent and self._overlay_dir:
|
||||
snapshots = _load_snapshots()
|
||||
snapshots[self._task_id] = str(self._overlay_dir)
|
||||
|
|
|
|||
|
|
@ -5,13 +5,9 @@ import shlex
|
|||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.environments.persistent_shell import PersistentShellMixin
|
||||
from tools.interrupt import is_interrupted
|
||||
from tools.environments.base import BaseEnvironment, _popen_bash
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -24,32 +20,22 @@ def _ensure_ssh_available() -> None:
|
|||
)
|
||||
|
||||
|
||||
class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
||||
class SSHEnvironment(BaseEnvironment):
|
||||
"""Run commands on a remote machine over SSH.
|
||||
|
||||
Uses SSH ControlMaster for connection persistence so subsequent
|
||||
commands are fast. Security benefit: the agent cannot modify its
|
||||
own code since execution happens on a separate machine.
|
||||
|
||||
Foreground commands are interruptible: the local ssh process is killed
|
||||
and a remote kill is attempted over the ControlMaster socket.
|
||||
|
||||
When ``persistent=True``, a single long-lived bash shell is kept alive
|
||||
over SSH and state (cwd, env vars, shell variables) persists across
|
||||
``execute()`` calls. Output capture uses file-based IPC on the remote
|
||||
host (stdout/stderr/exit-code written to temp files, polled via fast
|
||||
ControlMaster one-shot reads).
|
||||
Spawn-per-call: every execute() spawns a fresh ``ssh ... bash -c`` process.
|
||||
Session snapshot preserves env vars across calls.
|
||||
CWD persists via in-band stdout markers.
|
||||
Uses SSH ControlMaster for connection reuse.
|
||||
"""
|
||||
|
||||
def __init__(self, host: str, user: str, cwd: str = "~",
|
||||
timeout: int = 60, port: int = 22, key_path: str = "",
|
||||
persistent: bool = False):
|
||||
timeout: int = 60, port: int = 22, key_path: str = ""):
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
self.host = host
|
||||
self.user = user
|
||||
self.port = port
|
||||
self.key_path = key_path
|
||||
self.persistent = persistent
|
||||
|
||||
self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh"
|
||||
self.control_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -57,10 +43,10 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
|||
_ensure_ssh_available()
|
||||
self._establish_connection()
|
||||
self._remote_home = self._detect_remote_home()
|
||||
self._sync_skills_and_credentials()
|
||||
self._last_sync_time: float = 0 # guarantees first _before_execute syncs
|
||||
self._sync_files()
|
||||
|
||||
if self.persistent:
|
||||
self._init_persistent_shell()
|
||||
self.init_session()
|
||||
|
||||
def _build_ssh_command(self, extra_args: list | None = None) -> list:
|
||||
cmd = ["ssh"]
|
||||
|
|
@ -102,12 +88,11 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
|||
return home
|
||||
except Exception:
|
||||
pass
|
||||
# Fallback: guess from username
|
||||
if self.user == "root":
|
||||
return "/root"
|
||||
return f"/home/{self.user}"
|
||||
|
||||
def _sync_skills_and_credentials(self) -> None:
|
||||
def _sync_files(self) -> None:
|
||||
"""Rsync skills directory and credential files to the remote host."""
|
||||
try:
|
||||
container_base = f"{self._remote_home}/.hermes"
|
||||
|
|
@ -122,7 +107,6 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
|||
rsync_base.extend(["-e", ssh_opts])
|
||||
dest_prefix = f"{self.user}@{self.host}"
|
||||
|
||||
# Sync individual credential files (remap /root/.hermes to detected home)
|
||||
for mount_entry in get_credential_file_mounts():
|
||||
remote_path = mount_entry["container_path"].replace("/root/.hermes", container_base, 1)
|
||||
parent_dir = str(Path(remote_path).parent)
|
||||
|
|
@ -136,7 +120,6 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
|||
else:
|
||||
logger.debug("SSH: rsync credential failed: %s", result.stderr.strip())
|
||||
|
||||
# Sync skill directories (local + external, remap to detected home)
|
||||
for skills_mount in get_skills_directory_mount(container_base=container_base):
|
||||
remote_path = skills_mount["container_path"]
|
||||
mkdir_cmd = self._build_ssh_command()
|
||||
|
|
@ -154,152 +137,19 @@ class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
|||
except Exception as e:
|
||||
logger.debug("SSH: could not sync skills/credentials: %s", e)
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
# Incremental sync before each command so mid-session credential
|
||||
# refreshes and skill updates are picked up.
|
||||
self._sync_skills_and_credentials()
|
||||
return super().execute(command, cwd, timeout=timeout, stdin_data=stdin_data)
|
||||
|
||||
_poll_interval_start: float = 0.15 # SSH: higher initial interval (150ms) for network latency
|
||||
|
||||
@property
|
||||
def _temp_prefix(self) -> str:
|
||||
return f"/tmp/hermes-ssh-{self._session_id}"
|
||||
|
||||
def _spawn_shell_process(self) -> subprocess.Popen:
|
||||
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None) -> subprocess.Popen:
|
||||
"""Spawn an SSH process that runs bash on the remote host."""
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append("bash -l")
|
||||
return subprocess.Popen(
|
||||
cmd,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
|
||||
def _read_temp_files(self, *paths: str) -> list[str]:
|
||||
if len(paths) == 1:
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(f"cat {paths[0]} 2>/dev/null")
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd, capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
return [result.stdout]
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
return [""]
|
||||
|
||||
delim = f"__HERMES_SEP_{self._session_id}__"
|
||||
script = "; ".join(
|
||||
f"cat {p} 2>/dev/null; echo '{delim}'" for p in paths
|
||||
)
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(script)
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd, capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
parts = result.stdout.split(delim + "\n")
|
||||
return [parts[i] if i < len(parts) else "" for i in range(len(paths))]
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
return [""] * len(paths)
|
||||
|
||||
def _kill_shell_children(self):
|
||||
if self._shell_pid is None:
|
||||
return
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(f"pkill -P {self._shell_pid} 2>/dev/null; true")
|
||||
try:
|
||||
subprocess.run(cmd, capture_output=True, timeout=5)
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
pass
|
||||
|
||||
def _cleanup_temp_files(self):
    """Remove every file matching ``<_temp_prefix>-*`` via an ssh call.

    Timeouts (5 s) and failures to start the ssh process are ignored.
    """
    argv = self._build_ssh_command()
    argv.append(f"rm -f {self._temp_prefix}-*")
    try:
        subprocess.run(argv, capture_output=True, timeout=5)
    except (subprocess.TimeoutExpired, OSError):
        pass
|
||||
|
||||
def _execute_oneshot(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
work_dir = cwd or self.cwd
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
# Keep ~ unquoted (for shell expansion) and quote only the subpath.
|
||||
if work_dir == "~":
|
||||
wrapped = f'cd ~ && {exec_command}'
|
||||
elif work_dir.startswith("~/"):
|
||||
wrapped = f'cd ~/{shlex.quote(work_dir[2:])} && {exec_command}'
|
||||
if login:
|
||||
cmd.extend(["bash", "-l", "-c", shlex.quote(cmd_string)])
|
||||
else:
|
||||
wrapped = f'cd {shlex.quote(work_dir)} && {exec_command}'
|
||||
effective_timeout = timeout or self.timeout
|
||||
cmd.extend(["bash", "-c", shlex.quote(cmd_string)])
|
||||
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
effective_stdin = sudo_stdin
|
||||
else:
|
||||
effective_stdin = stdin_data
|
||||
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(wrapped)
|
||||
|
||||
kwargs = self._build_run_kwargs(timeout, effective_stdin)
|
||||
kwargs.pop("timeout", None)
|
||||
_output_chunks = []
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
|
||||
if effective_stdin:
|
||||
try:
|
||||
proc.stdin.write(effective_stdin)
|
||||
proc.stdin.close()
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
|
||||
def _drain():
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
_output_chunks.append(line)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
reader = threading.Thread(target=_drain, daemon=True)
|
||||
reader.start()
|
||||
deadline = time.monotonic() + effective_timeout
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=1)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(_output_chunks) + "\n[Command interrupted]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return self._timeout_result(effective_timeout)
|
||||
time.sleep(0.2)
|
||||
|
||||
reader.join(timeout=5)
|
||||
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
|
||||
return _popen_bash(cmd, stdin_data)
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
if self.control_socket.exists():
|
||||
try:
|
||||
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
|
||||
|
|
|
|||
|
|
@ -148,6 +148,7 @@ def _handle_send(args):
|
|||
"slack": Platform.SLACK,
|
||||
"whatsapp": Platform.WHATSAPP,
|
||||
"signal": Platform.SIGNAL,
|
||||
"bluebubbles": Platform.BLUEBUBBLES,
|
||||
"matrix": Platform.MATRIX,
|
||||
"mattermost": Platform.MATTERMOST,
|
||||
"homeassistant": Platform.HOMEASSISTANT,
|
||||
|
|
@ -396,6 +397,8 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None,
|
|||
result = await _send_feishu(pconfig, chat_id, chunk, thread_id=thread_id)
|
||||
elif platform == Platform.WECOM:
|
||||
result = await _send_wecom(pconfig.extra, chat_id, chunk)
|
||||
elif platform == Platform.BLUEBUBBLES:
|
||||
result = await _send_bluebubbles(pconfig.extra, chat_id, chunk)
|
||||
else:
|
||||
result = {"error": f"Direct sending not yet implemented for {platform.value}"}
|
||||
|
||||
|
|
@ -870,6 +873,33 @@ async def _send_wecom(extra, chat_id, message):
|
|||
return _error(f"WeCom send failed: {e}")
|
||||
|
||||
|
||||
async def _send_bluebubbles(extra, chat_id, message):
|
||||
"""Send via BlueBubbles iMessage server using the adapter's REST API."""
|
||||
try:
|
||||
from gateway.platforms.bluebubbles import BlueBubblesAdapter, check_bluebubbles_requirements
|
||||
if not check_bluebubbles_requirements():
|
||||
return {"error": "BlueBubbles requirements not met (need aiohttp + httpx)."}
|
||||
except ImportError:
|
||||
return {"error": "BlueBubbles adapter not available."}
|
||||
|
||||
try:
|
||||
from gateway.config import PlatformConfig
|
||||
pconfig = PlatformConfig(extra=extra)
|
||||
adapter = BlueBubblesAdapter(pconfig)
|
||||
connected = await adapter.connect()
|
||||
if not connected:
|
||||
return _error("BlueBubbles: failed to connect to server")
|
||||
try:
|
||||
result = await adapter.send(chat_id, message)
|
||||
if not result.success:
|
||||
return _error(f"BlueBubbles send failed: {result.error}")
|
||||
return {"success": True, "platform": "bluebubbles", "chat_id": chat_id, "message_id": result.message_id}
|
||||
finally:
|
||||
await adapter.disconnect()
|
||||
except Exception as e:
|
||||
return _error(f"BlueBubbles send failed: {e}")
|
||||
|
||||
|
||||
async def _send_feishu(pconfig, chat_id, message, media_files=None, thread_id=None):
|
||||
"""Send via Feishu/Lark using the adapter's send pipeline."""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -326,8 +326,123 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
|
|||
if "HERMES_SPINNER_PAUSE" in os.environ:
|
||||
del os.environ["HERMES_SPINNER_PAUSE"]
|
||||
|
||||
def _safe_command_preview(command: Any, limit: int = 200) -> str:
|
||||
"""Return a log-safe preview for possibly-invalid command values."""
|
||||
if command is None:
|
||||
return "<None>"
|
||||
if isinstance(command, str):
|
||||
return command[:limit]
|
||||
try:
|
||||
return repr(command)[:limit]
|
||||
except Exception:
|
||||
return f"<{type(command).__name__}>"
|
||||
|
||||
def _transform_sudo_command(command: str) -> tuple[str, str | None]:
|
||||
def _looks_like_env_assignment(token: str) -> bool:
|
||||
"""Return True when *token* is a leading shell environment assignment."""
|
||||
if "=" not in token or token.startswith("="):
|
||||
return False
|
||||
name, _value = token.split("=", 1)
|
||||
return bool(re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", name))
|
||||
|
||||
|
||||
def _read_shell_token(command: str, start: int) -> tuple[str, int]:
|
||||
"""Read one shell token, preserving quotes/escapes, starting at *start*."""
|
||||
i = start
|
||||
n = len(command)
|
||||
|
||||
while i < n:
|
||||
ch = command[i]
|
||||
if ch.isspace() or ch in ";|&()":
|
||||
break
|
||||
if ch == "'":
|
||||
i += 1
|
||||
while i < n and command[i] != "'":
|
||||
i += 1
|
||||
if i < n:
|
||||
i += 1
|
||||
continue
|
||||
if ch == '"':
|
||||
i += 1
|
||||
while i < n:
|
||||
inner = command[i]
|
||||
if inner == "\\" and i + 1 < n:
|
||||
i += 2
|
||||
continue
|
||||
if inner == '"':
|
||||
i += 1
|
||||
break
|
||||
i += 1
|
||||
continue
|
||||
if ch == "\\" and i + 1 < n:
|
||||
i += 2
|
||||
continue
|
||||
i += 1
|
||||
|
||||
return command[start:i], i
|
||||
|
||||
|
||||
def _rewrite_real_sudo_invocations(command: str) -> tuple[str, bool]:
    """Rewrite only real unquoted sudo command words, not plain text mentions.

    The command is scanned left to right while tracking whether the next word
    is in *command position* (start of input, after a newline, after
    ``; | & (``, ``&&``, ``||`` or ``;;``, after a leading ``NAME=value``
    assignment, or after a shell reserved word such as ``then`` or ``do``).
    A bare ``sudo`` word in command position is replaced by ``sudo -S -p ''``;
    everything else, including quoted text, arguments and comments, is copied
    through unchanged.

    Args:
        command: The shell command string to scan.

    Returns:
        A ``(rewritten, found)`` pair: the command with qualifying ``sudo``
        words replaced, and whether at least one replacement was made.
    """
    # Reserved words that occupy command position themselves and leave the
    # *next* word in command position too, e.g. ``if x; then sudo y; fi`` or
    # ``{ sudo y; }``. Without this, a real ``sudo`` following one of them
    # would be mistaken for an argument and left unrewritten.
    command_prefix_words = frozenset(
        {"!", "{", "do", "elif", "else", "if", "then", "time", "until", "while"}
    )

    out: list[str] = []
    i = 0
    n = len(command)
    command_start = True
    found = False

    while i < n:
        ch = command[i]

        if ch.isspace():
            out.append(ch)
            if ch == "\n":
                # A newline terminates the current command.
                command_start = True
            i += 1
            continue

        if ch == "#" and command_start:
            # Comment: copy up to (not including) the newline untouched so a
            # "sudo" mentioned in it is never rewritten.
            comment_end = command.find("\n", i)
            if comment_end == -1:
                out.append(command[i:])
                break
            out.append(command[i:comment_end])
            i = comment_end
            continue

        if command.startswith(("&&", "||", ";;"), i):
            out.append(command[i:i + 2])
            i += 2
            command_start = True
            continue

        if ch in ";|&(":
            out.append(ch)
            i += 1
            command_start = True
            continue

        if ch == ")":
            out.append(ch)
            i += 1
            command_start = False
            continue

        # Every delimiter was consumed above, so the token is never empty and
        # the scan always advances.
        token, next_i = _read_shell_token(command, i)
        if command_start and token == "sudo":
            out.append("sudo -S -p ''")
            found = True
        else:
            out.append(token)

        # Stay in command position only across leading env assignments and
        # reserved words; any other word makes the following words arguments.
        command_start = command_start and (
            _looks_like_env_assignment(token) or token in command_prefix_words
        )
        i = next_i

    return "".join(out), found
|
||||
|
||||
|
||||
def _transform_sudo_command(command: str | None) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
Transform sudo commands to use -S flag if SUDO_PASSWORD is available.
|
||||
|
||||
|
|
@ -362,37 +477,26 @@ def _transform_sudo_command(command: str) -> tuple[str, str | None]:
|
|||
Command runs as-is (fails gracefully with "sudo: a password is required").
|
||||
"""
|
||||
global _cached_sudo_password
|
||||
import re
|
||||
|
||||
# Check if command even contains sudo
|
||||
if not re.search(r'\bsudo\b', command):
|
||||
return command, None # No sudo in command, nothing to do
|
||||
if command is None:
|
||||
return None, None
|
||||
transformed, has_real_sudo = _rewrite_real_sudo_invocations(command)
|
||||
if not has_real_sudo:
|
||||
return command, None
|
||||
|
||||
# Try to get password from: env var -> session cache -> interactive prompt
|
||||
sudo_password = os.getenv("SUDO_PASSWORD", "") or _cached_sudo_password
|
||||
has_configured_password = "SUDO_PASSWORD" in os.environ
|
||||
sudo_password = os.environ.get("SUDO_PASSWORD", "") if has_configured_password else _cached_sudo_password
|
||||
|
||||
if not sudo_password:
|
||||
# No password configured - check if we're in interactive mode
|
||||
if os.getenv("HERMES_INTERACTIVE"):
|
||||
# Prompt user for password
|
||||
sudo_password = _prompt_for_sudo_password(timeout_seconds=45)
|
||||
if sudo_password:
|
||||
_cached_sudo_password = sudo_password # Cache for session
|
||||
if not has_configured_password and not sudo_password and os.getenv("HERMES_INTERACTIVE"):
|
||||
sudo_password = _prompt_for_sudo_password(timeout_seconds=45)
|
||||
if sudo_password:
|
||||
_cached_sudo_password = sudo_password
|
||||
|
||||
if not sudo_password:
|
||||
return command, None # No password, let it fail gracefully
|
||||
if has_configured_password or sudo_password:
|
||||
# Trailing newline is required: sudo -S reads one line for the password.
|
||||
return transformed, sudo_password + "\n"
|
||||
|
||||
def replace_sudo(match):
|
||||
# Replace bare 'sudo' with 'sudo -S -p ""'.
|
||||
# The password is returned as sudo_stdin and must be written to the
|
||||
# process's stdin pipe by the caller — it never appears in any
|
||||
# command-line argument or shell string.
|
||||
return "sudo -S -p ''"
|
||||
|
||||
# Match 'sudo' at word boundaries (not 'visudo' or 'sudoers')
|
||||
transformed = re.sub(r'\bsudo\b', replace_sudo, command)
|
||||
# Trailing newline is required: sudo -S reads one line for the password.
|
||||
return transformed, sudo_password + "\n"
|
||||
return command, None
|
||||
|
||||
|
||||
# Environment classes now live in tools/environments/
|
||||
|
|
@ -611,9 +715,7 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
|
|||
docker_env = cc.get("docker_env", {})
|
||||
|
||||
if env_type == "local":
|
||||
lc = local_config or {}
|
||||
return _LocalEnvironment(cwd=cwd, timeout=timeout,
|
||||
persistent=lc.get("persistent", False))
|
||||
return _LocalEnvironment(cwd=cwd, timeout=timeout)
|
||||
|
||||
elif env_type == "docker":
|
||||
return _DockerEnvironment(
|
||||
|
|
@ -705,7 +807,6 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
|
|||
key_path=ssh_config.get("key", ""),
|
||||
cwd=cwd,
|
||||
timeout=timeout,
|
||||
persistent=ssh_config.get("persistent", False),
|
||||
)
|
||||
|
||||
else:
|
||||
|
|
@ -817,6 +918,23 @@ def get_active_env(task_id: str):
|
|||
return _active_environments.get(task_id)
|
||||
|
||||
|
||||
def is_persistent_env(task_id: str) -> bool:
    """Report whether the active environment for *task_id* is persistent.

    The check reads the environment object's ``_persistent`` attribute and
    treats a missing attribute as "not persistent".

    Used by the agent loop to skip per-turn teardown for backends that are
    meant to survive between turns (docker with ``container_persistent``,
    daytona, modal, etc.). Non-persistent backends (e.g. Morph) are still torn
    down at end-of-turn to prevent leakage, and the idle reaper
    (``_cleanup_inactive_envs``) handles persistent environments once they
    exceed ``terminal.lifetime_seconds``.

    NOTE(review): the earlier wording tied this to
    ``persistent_filesystem=True``; the code only inspects ``_persistent`` —
    confirm the two are kept in sync by the environment classes.

    Args:
        task_id: Identifier of the task whose environment is looked up.

    Returns:
        True if an environment is active for *task_id* and its ``_persistent``
        attribute is truthy; False otherwise.
    """
    active_env = get_active_env(task_id)
    return active_env is not None and bool(getattr(active_env, "_persistent", False))
|
||||
|
||||
|
||||
def get_active_environments_info() -> Dict[str, Any]:
|
||||
"""Get information about currently active environments."""
|
||||
info = {
|
||||
|
|
@ -1036,6 +1154,18 @@ def terminal_tool(
|
|||
# Note: force parameter is internal only, not exposed to model API
|
||||
"""
|
||||
try:
|
||||
if not isinstance(command, str):
|
||||
logger.warning(
|
||||
"Rejected invalid terminal command value: %s",
|
||||
type(command).__name__,
|
||||
)
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": f"Invalid command: expected string, got {type(command).__name__}",
|
||||
"status": "error",
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# Get configuration
|
||||
config = _get_env_config()
|
||||
env_type = config["env_type"]
|
||||
|
|
@ -1193,7 +1323,7 @@ def terminal_tool(
|
|||
workdir_error = _validate_workdir(workdir)
|
||||
if workdir_error:
|
||||
logger.warning("Blocked dangerous workdir: %s (command: %s)",
|
||||
workdir[:200], command[:200])
|
||||
workdir[:200], _safe_command_preview(command))
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
|
|
@ -1333,12 +1463,12 @@ def terminal_tool(
|
|||
retry_count += 1
|
||||
wait_time = 2 ** retry_count
|
||||
logger.warning("Execution error, retrying in %ds (attempt %d/%d) - Command: %s - Error: %s: %s - Task: %s, Backend: %s",
|
||||
wait_time, retry_count, max_retries, command[:200], type(e).__name__, e, effective_task_id, env_type)
|
||||
wait_time, retry_count, max_retries, _safe_command_preview(command), type(e).__name__, e, effective_task_id, env_type)
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
|
||||
logger.error("Execution failed after %d retries - Command: %s - Error: %s: %s - Task: %s, Backend: %s",
|
||||
max_retries, command[:200], type(e).__name__, e, effective_task_id, env_type)
|
||||
max_retries, _safe_command_preview(command), type(e).__name__, e, effective_task_id, env_type)
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
|
|
|
|||
|
|
@ -311,6 +311,12 @@ TOOLSETS = {
|
|||
"includes": []
|
||||
},
|
||||
|
||||
"hermes-bluebubbles": {
|
||||
"description": "BlueBubbles iMessage bot toolset - Apple iMessage via local BlueBubbles server",
|
||||
"tools": _HERMES_CORE_TOOLS,
|
||||
"includes": []
|
||||
},
|
||||
|
||||
"hermes-homeassistant": {
|
||||
"description": "Home Assistant bot toolset - smart home event monitoring and control",
|
||||
"tools": _HERMES_CORE_TOOLS,
|
||||
|
|
@ -368,7 +374,7 @@ TOOLSETS = {
|
|||
"hermes-gateway": {
|
||||
"description": "Gateway toolset - union of all messaging platform tools",
|
||||
"tools": [],
|
||||
"includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-signal", "hermes-homeassistant", "hermes-email", "hermes-sms", "hermes-mattermost", "hermes-matrix", "hermes-dingtalk", "hermes-feishu", "hermes-wecom", "hermes-webhook"]
|
||||
"includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-signal", "hermes-bluebubbles", "hermes-homeassistant", "hermes-email", "hermes-sms", "hermes-mattermost", "hermes-matrix", "hermes-dingtalk", "hermes-feishu", "hermes-wecom", "hermes-webhook"]
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -116,9 +116,9 @@ hermes-agent/
|
|||
│ ├── mirror.py # Cross-session message mirroring
|
||||
│ ├── status.py # Token locks, profile-scoped process tracking
|
||||
│ ├── builtin_hooks/ # Always-registered hooks
|
||||
│ └── platforms/ # 14 adapters: telegram, discord, slack, whatsapp,
|
||||
│ └── platforms/ # 15 adapters: telegram, discord, slack, whatsapp,
|
||||
│ # signal, matrix, mattermost, email, sms,
|
||||
│ # dingtalk, feishu, wecom, homeassistant, webhook
|
||||
│ # dingtalk, feishu, wecom, bluebubbles, homeassistant, webhook
|
||||
│
|
||||
├── acp_adapter/ # ACP server (VS Code / Zed / JetBrains)
|
||||
├── cron/ # Scheduler (jobs.py, scheduler.py)
|
||||
|
|
|
|||
|
|
@ -153,6 +153,7 @@ Cron job results can be delivered to any supported platform:
|
|||
| DingTalk | `dingtalk` | Deliver to DingTalk |
|
||||
| Feishu | `feishu` | Deliver to Feishu |
|
||||
| WeCom | `wecom` | Deliver to WeCom |
|
||||
| BlueBubbles | `bluebubbles` | Deliver to iMessage via BlueBubbles |
|
||||
|
||||
For Telegram topics, use the format `telegram:<chat_id>:<thread_id>` (e.g., `telegram:-1001234567890:17585`).
|
||||
|
||||
|
|
|
|||
|
|
@ -160,6 +160,7 @@ gateway/platforms/
|
|||
├── dingtalk.py # DingTalk WebSocket
|
||||
├── feishu.py # Feishu/Lark WebSocket or webhook
|
||||
├── wecom.py # WeCom (WeChat Work) callback
|
||||
├── bluebubbles.py # Apple iMessage via BlueBubbles macOS server
|
||||
├── webhook.py # Inbound/outbound webhook adapter
|
||||
├── api_server.py # REST API server adapter
|
||||
└── homeassistant.py # Home Assistant conversation integration
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ This module requires NixOS. For non-NixOS systems (macOS, other Linux distros),
|
|||
# /etc/nixos/flake.nix (or your system flake)
|
||||
{
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.11";
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
|
||||
hermes-agent.url = "github:NousResearch/hermes-agent";
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ It's not a coding copilot tethered to an IDE or a chatbot wrapper around a singl
|
|||
|
||||
- **A closed learning loop** — Agent-curated memory with periodic nudges, autonomous skill creation, skill self-improvement during use, FTS5 cross-session recall with LLM summarization, and [Honcho](https://github.com/plastic-labs/honcho) dialectic user modeling
|
||||
- **Runs anywhere, not just your laptop** — 6 terminal backends: local, Docker, SSH, Daytona, Singularity, Modal. Daytona and Modal offer serverless persistence — your environment hibernates when idle, costing nearly nothing
|
||||
- **Lives where you do** — CLI, Telegram, Discord, Slack, WhatsApp, Signal, Matrix, Mattermost, Email, SMS, DingTalk, Feishu, WeCom, Home Assistant — 14+ platforms from one gateway
|
||||
- **Lives where you do** — CLI, Telegram, Discord, Slack, WhatsApp, Signal, Matrix, Mattermost, Email, SMS, DingTalk, Feishu, WeCom, BlueBubbles, Home Assistant — 15+ platforms from one gateway
|
||||
- **Built by model trainers** — Created by [Nous Research](https://nousresearch.com), the lab behind Hermes, Nomos, and Psyche. Works with [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai), OpenAI, or any endpoint
|
||||
- **Scheduled automations** — Built-in cron with delivery to any platform
|
||||
- **Delegates & parallelizes** — Spawn isolated subagents for parallel workstreams. Programmatic Tool Calling via `execute_code` collapses multi-step pipelines into single inference calls
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue