diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000000..2363d4ca8a --- /dev/null +++ b/.editorconfig @@ -0,0 +1,18 @@ +root = true + +[*] +indent_style = space +indent_size = 4 +end_of_line = lf +charset = utf-8 +trim_trailing_whitespace = true +insert_final_newline = true + +[*.{yml,yaml,json,toml}] +indent_size = 2 + +[*.md] +trim_trailing_whitespace = false + +[Makefile] +indent_style = tab diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 5496eb534f..1ab6c0d4e7 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -46,7 +46,7 @@ Fixes # - [ ] My commit messages follow [Conventional Commits](https://www.conventionalcommits.org/) (`fix(scope):`, `feat(scope):`, etc.) - [ ] I searched for [existing PRs](https://github.com/NousResearch/hermes-agent/pulls) to make sure this isn't a duplicate - [ ] My PR contains **only** changes related to this fix/feature (no unrelated commits) -- [ ] I've run `pytest tests/ -q` and all tests pass +- [ ] I've run `make check` (lint + test) and all checks pass - [ ] I've added tests for my changes (required for bug fixes, strongly encouraged for features) - [ ] I've tested on my platform: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9ebaa7f4b8..6a7da04927 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,4 +1,4 @@ -name: Tests +name: CI on: push: @@ -6,37 +6,42 @@ on: pull_request: branches: [main] -# Cancel in-progress runs for the same PR/branch concurrency: - group: tests-${{ github.ref }} + group: ci-${{ github.ref }} cancel-in-progress: true +env: + SRC: >- + run_agent.py model_tools.py toolsets.py cli.py hermes_state.py batch_runner.py + tools/ hermes_cli/ gateway/ agent/ cron/ + jobs: + lint: + runs-on: ubuntu-latest + timeout-minutes: 3 + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/setup-uv@v5 + - run: uvx ruff check $SRC + - run: uvx ruff format --check $SRC + test: runs-on: ubuntu-latest timeout-minutes: 10 steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Install uv - uses: astral-sh/setup-uv@v5 - - - name: Set up Python 3.11 - run: uv python install 3.11 - - - name: Install dependencies - run: | + - uses: actions/checkout@v4 + - uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + - run: uv python install 3.11 + - run: | uv venv .venv --python 3.11 source .venv/bin/activate uv pip install -e ".[all,dev]" - - - name: Run tests - run: | + - run: | source .venv/bin/activate python -m pytest tests/ -q --ignore=tests/integration --tb=short env: - # Ensure tests don't accidentally call real APIs OPENROUTER_API_KEY: "" OPENAI_API_KEY: "" NOUS_API_KEY: "" diff --git a/.gitignore b/.gitignore index 78a3829429..bef3c773bc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,51 +1,53 @@ -/venv/ -/_pycache/ -*.pyc* +# Python __pycache__/ +*.pyc +*.pyo +*.egg-info/ +dist/ +build/ + +# Environments .venv/ +venv/ + +# Tools +.ruff_cache/ +.mypy_cache/ +.pytest_cache/ + +# Editors .vscode/ +.idea/ + +# Secrets & config .env .env.local -.env.development.local -.env.test.local -.env.production.local -.env.development -.env.test -export* -__pycache__/model_tools.cpython-310.pyc -__pycache__/web_tools.cpython-310.pyc +.env.*.local +*.pem +*.ppk + +# Node +node_modules/ + +# Project-specific logs/ data/ -.pytest_cache/ tmp/ -temp_vision_images/ -hermes-*/* -examples/ -tests/quick_test_dataset.jsonl -tests/sample_dataset.jsonl -run_datagen_kimik2-thinking.sh -run_datagen_megascience_glm4-6.sh -run_datagen_sonnet.sh -source-data/* -run_datagen_megascience_glm4-6.sh -data/* -node_modules/ +wandb/ +images/ browser-use/ agent-browser/ -# Private keys -*.ppk -*.pem -privvy* -images/ -__pycache__/ -hermes_agent.egg-info/ -wandb/ -testlogs - -# CLI config (may contain sensitive SSH paths) +source-data/ +testlogs/ +ignored/ +.worktrees/ +temp_vision_images/ cli-config.yaml - -# Skills Hub state (lives in ~/.hermes/skills/.hub/ at runtime, but just in case) skills/.hub/ -ignored/ -.worktrees/ +hermes-*/* +examples/ +export* +privvy* +run_datagen_*.sh +tests/quick_test_dataset.jsonl +tests/sample_dataset.jsonl diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..0b6bd87e1e --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.15.5 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-merge-conflict + - id: check-yaml + args: [--allow-multiple-documents] + - id: check-added-large-files + args: [--maxkb=500] diff --git a/AGENTS.md b/AGENTS.md index 7aef595a36..167a1a865f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -5,7 +5,8 @@ Instructions for AI coding assistants and developers working on the hermes-agent ## Development Environment ```bash -source .venv/bin/activate # ALWAYS activate before running Python +make setup # First time: creates .venv, installs deps, sets up pre-commit +source .venv/bin/activate ``` ## Project Structure @@ -228,15 +229,27 @@ The `_isolate_hermes_home` autouse fixture in `tests/conftest.py` redirects `HER --- -## Testing +## Development Commands + +```bash +make setup # First time: .venv + deps + pre-commit hooks +make check # Lint + test (mirrors CI — run before pushing) +make lint # Ruff check +make fmt # Ruff format + auto-fix +make test # Full test suite (~2500 tests, ~2 min) +make test-fast # Tests with fail-fast (-x) +make test-watch # Rerun tests on file changes +make dev-cli # Auto-restart CLI on file changes +make dev-gateway # Auto-restart gateway on file changes +``` + +For targeted testing, use `pytest` directly: ```bash -source .venv/bin/activate -python -m pytest tests/ -q # Full suite (~2500 tests, ~2 min) python -m pytest tests/test_model_tools.py -q # Toolset resolution python -m pytest tests/test_cli_init.py -q # CLI config loading python -m pytest tests/gateway/ -q # Gateway tests python -m pytest tests/tools/ -q # Tool-level tests ``` -Always run the full suite before pushing changes. +Formatting is enforced by **ruff** (config in `pyproject.toml`). Pre-commit hooks run on every commit. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6ed6c833e4..b4d703db82 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -65,18 +65,7 @@ If your skill is specialized, community-contributed, or niche, it's better suite ```bash git clone --recurse-submodules https://github.com/NousResearch/hermes-agent.git cd hermes-agent - -# Create venv with Python 3.11 -uv venv venv --python 3.11 -export VIRTUAL_ENV="$(pwd)/venv" - -# Install with all extras (messaging, cron, CLI menus, dev tools) -uv pip install -e ".[all,dev]" -uv pip install -e "./mini-swe-agent" -uv pip install -e "./tinker-atropos" - -# Optional: browser tools -npm install +make setup # creates .venv, installs all deps ``` ### Configure for development @@ -90,22 +79,16 @@ touch ~/.hermes/.env echo 'OPENROUTER_API_KEY=sk-or-v1-your-key' >> ~/.hermes/.env ``` -### Run +### Common commands ```bash -# Symlink for global access -mkdir -p ~/.local/bin -ln -sf "$(pwd)/venv/bin/hermes" ~/.local/bin/hermes - -# Verify -hermes doctor -hermes chat -q "Hello" -``` - -### Run tests - -```bash -pytest tests/ -v +make test # run unit tests +make lint # ruff check +make fmt # ruff format + fix +make check # lint + test (same as CI) +make dev-cli # auto-restart hermes CLI on file changes +make dev-gateway # auto-restart gateway on file changes +make test-watch # rerun tests on file changes ``` --- @@ -227,7 +210,7 @@ User message → AIAgent._run_agent_loop() ## Code Style -- **PEP 8** with practical exceptions (we don't enforce strict line length) +- **Formatting**: Enforced by **ruff** (config in `pyproject.toml`). Run `make fmt` to auto-fix, `make lint` to check. Pre-commit hooks handle this automatically. - **Comments**: Only when explaining non-obvious intent, trade-offs, or API quirks. Don't narrate what the code does — `# increment counter` adds nothing - **Error handling**: Catch specific exceptions. Log with `logger.warning()`/`logger.error()` — use `exc_info=True` for unexpected errors so stack traces appear in logs - **Cross-platform**: Never assume Unix. See [Cross-Platform Compatibility](#cross-platform-compatibility) @@ -457,7 +440,7 @@ refactor/description # Code restructuring ### Before submitting -1. **Run tests**: `pytest tests/ -v` +1. **Run checks**: `make check` (lint + test — same as CI) 2. **Test manually**: Run `hermes` and exercise the code path you changed 3. **Check cross-platform impact**: If you touch file I/O, process management, or terminal handling, consider Windows and macOS 4. **Keep PRs focused**: One logical change per PR. Don't mix a bug fix with a refactor with a new feature. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000..352d57b990 --- /dev/null +++ b/Makefile @@ -0,0 +1,69 @@ +.DEFAULT_GOAL := help +SHELL := /bin/bash +VENV := .venv +UV := uv + +SRC := run_agent.py model_tools.py toolsets.py cli.py hermes_state.py batch_runner.py \ + tools/ hermes_cli/ gateway/ agent/ cron/ + +# ─── Setup ────────────────────────────────────────────────────────────────────── + +.PHONY: setup sync clean + +setup: ## Full dev setup (venv + deps + pre-commit) + $(UV) venv $(VENV) --python 3.11 + . $(VENV)/bin/activate && $(UV) pip install -e ".[all,dev]" + . $(VENV)/bin/activate && $(UV) pip install -e "./mini-swe-agent" + . $(VENV)/bin/activate && pre-commit install + @echo "\n✅ Setup complete. Run: source $(VENV)/bin/activate" + +sync: ## Reinstall deps into existing venv + . $(VENV)/bin/activate && $(UV) pip install -e ".[all,dev]" + +clean: ## Remove build artifacts and caches + rm -rf .ruff_cache .mypy_cache .pytest_cache dist build *.egg-info + find . -type d -name __pycache__ -not -path "./.venv/*" -exec rm -rf {} + + +# ─── Quality ──────────────────────────────────────────────────────────────────── + +.PHONY: lint fmt check + +lint: ## Check lint + formatting (no changes) + . $(VENV)/bin/activate && ruff check $(SRC) + . $(VENV)/bin/activate && ruff format --check $(SRC) + +fmt: ## Auto-fix lint + format + . $(VENV)/bin/activate && ruff format $(SRC) + . $(VENV)/bin/activate && ruff check --fix $(SRC) + +check: lint test ## Lint + test (mirrors CI) + +# ─── Test ─────────────────────────────────────────────────────────────────────── + +.PHONY: test test-fast test-watch + +test: ## Run full test suite + . $(VENV)/bin/activate && python -m pytest tests/ -q --ignore=tests/integration --tb=short + +test-fast: ## Run tests with fail-fast + . $(VENV)/bin/activate && python -m pytest tests/ -q --ignore=tests/integration --tb=short -x + +test-watch: ## Rerun tests on file changes + . $(VENV)/bin/activate && python -m watchfiles "python -m pytest tests/ -q --ignore=tests/integration --tb=short -x" $(SRC) tests/ + +# ─── Dev Servers ──────────────────────────────────────────────────────────────── + +.PHONY: dev-cli dev-gateway + +dev-cli: ## Auto-restart CLI on file changes + . $(VENV)/bin/activate && python -m watchfiles "python -m hermes_cli.main" $(SRC) + +dev-gateway: ## Auto-restart gateway on file changes + . $(VENV)/bin/activate && python -m watchfiles "python -m gateway.run" $(SRC) + +# ─── Misc ─────────────────────────────────────────────────────────────────────── + +.PHONY: help + +help: ## Show this help + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-15s\033[0m %s\n", $$1, $$2}' diff --git a/README.md b/README.md index aaa541d5d8..862c3113d1 100644 --- a/README.md +++ b/README.md @@ -95,12 +95,8 @@ Quick start for contributors: ```bash git clone --recurse-submodules https://github.com/NousResearch/hermes-agent.git cd hermes-agent -curl -LsSf https://astral.sh/uv/install.sh | sh -uv venv .venv --python 3.11 -source .venv/bin/activate -uv pip install -e ".[all,dev]" -uv pip install -e "./mini-swe-agent" -python -m pytest tests/ -q +make setup # creates .venv, installs everything +make check # lint + test (same as CI) ``` --- diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index 57c3c11869..09b5a0c624 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -34,7 +34,7 @@ import logging import os from pathlib import Path from types import SimpleNamespace -from typing import Any, Dict, List, Optional, Tuple +from typing import Any from openai import OpenAI @@ -43,7 +43,7 @@ from hermes_constants import OPENROUTER_BASE_URL logger = logging.getLogger(__name__) # Default auxiliary models for direct API-key providers (cheap/fast for side tasks) -_API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = { +_API_KEY_PROVIDER_AUX_MODELS: dict[str, str] = { "zai": "glm-4.5-flash", "kimi-coding": "kimi-k2-turbo-preview", "minimax": "MiniMax-M2.5-highspeed", @@ -102,7 +102,7 @@ def _convert_content_for_responses(content: Any) -> Any: if not isinstance(content, list): return str(content) if content else "" - converted: List[Dict[str, Any]] = [] + converted: list[dict[str, Any]] = [] for part in content: if not isinstance(part, dict): continue @@ -113,7 +113,7 @@ def _convert_content_for_responses(content: Any) -> Any: # chat.completions nests the URL: {"image_url": {"url": "..."}} image_data = part.get("image_url", {}) url = image_data.get("url", "") if isinstance(image_data, dict) else str(image_data) - entry: Dict[str, Any] = {"type": "input_image", "image_url": url} + entry: dict[str, Any] = {"type": "input_image", "image_url": url} # Preserve detail if specified detail = image_data.get("detail") if isinstance(image_data, dict) else None if detail: @@ -148,19 +148,21 @@ class _CodexCompletionsAdapter: # Convert chat.completions multimodal content blocks to Responses # API format (input_text / input_image instead of text / image_url). instructions = "You are a helpful assistant." - input_msgs: List[Dict[str, Any]] = [] + input_msgs: list[dict[str, Any]] = [] for msg in messages: role = msg.get("role", "user") content = msg.get("content") or "" if role == "system": instructions = content if isinstance(content, str) else str(content) else: - input_msgs.append({ - "role": role, - "content": _convert_content_for_responses(content), - }) + input_msgs.append( + { + "role": role, + "content": _convert_content_for_responses(content), + } + ) - resp_kwargs: Dict[str, Any] = { + resp_kwargs: dict[str, Any] = { "model": model, "instructions": instructions, "input": input_msgs or [{"role": "user", "content": ""}], @@ -179,18 +181,20 @@ class _CodexCompletionsAdapter: name = fn.get("name") if not name: continue - converted.append({ - "type": "function", - "name": name, - "description": fn.get("description", ""), - "parameters": fn.get("parameters", {}), - }) + converted.append( + { + "type": "function", + "name": name, + "description": fn.get("description", ""), + "parameters": fn.get("parameters", {}), + } + ) if converted: resp_kwargs["tools"] = converted # Stream and collect the response - text_parts: List[str] = [] - tool_calls_raw: List[Any] = [] + text_parts: list[str] = [] + tool_calls_raw: list[Any] = [] usage = None try: @@ -208,14 +212,16 @@ class _CodexCompletionsAdapter: if ptype in ("output_text", "text"): text_parts.append(getattr(part, "text", "")) elif item_type == "function_call": - tool_calls_raw.append(SimpleNamespace( - id=getattr(item, "call_id", ""), - type="function", - function=SimpleNamespace( - name=getattr(item, "name", ""), - arguments=getattr(item, "arguments", "{}"), - ), - )) + tool_calls_raw.append( + SimpleNamespace( + id=getattr(item, "call_id", ""), + type="function", + function=SimpleNamespace( + name=getattr(item, "name", ""), + arguments=getattr(item, "arguments", "{}"), + ), + ) + ) resp_usage = getattr(final, "usage", None) if resp_usage: @@ -285,6 +291,7 @@ class _AsyncCodexCompletionsAdapter: async def create(self, **kwargs) -> Any: import asyncio + return await asyncio.to_thread(self._sync.create, **kwargs) @@ -304,7 +311,7 @@ class AsyncCodexAuxiliaryClient: self.base_url = sync_wrapper.base_url -def _read_nous_auth() -> Optional[dict]: +def _read_nous_auth() -> dict | None: """Read and validate ~/.hermes/auth.json for an active Nous provider. Returns the provider state dict if Nous is active with tokens, @@ -336,10 +343,11 @@ def _nous_base_url() -> str: return os.getenv("NOUS_INFERENCE_BASE_URL", _NOUS_DEFAULT_BASE_URL) -def _read_codex_access_token() -> Optional[str]: +def _read_codex_access_token() -> str | None: """Read a valid Codex OAuth access token from Hermes auth store (~/.hermes/auth.json).""" try: from hermes_cli.auth import _read_codex_tokens + data = _read_codex_tokens() tokens = data.get("tokens", {}) access_token = tokens.get("access_token") @@ -351,7 +359,7 @@ def _read_codex_access_token() -> Optional[str]: return None -def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]: +def _resolve_api_key_provider() -> tuple[OpenAI | None, str | None]: """Try each API-key provider in PROVIDER_REGISTRY order. Returns (client, model) for the first provider whose env var is set, @@ -398,6 +406,7 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]: # ── Provider resolution helpers ───────────────────────────────────────────── + def _get_auxiliary_provider(task: str = "") -> str: """Read the provider override for a specific auxiliary task. @@ -413,16 +422,15 @@ def _get_auxiliary_provider(task: str = "") -> str: return "auto" -def _try_openrouter() -> Tuple[Optional[OpenAI], Optional[str]]: +def _try_openrouter() -> tuple[OpenAI | None, str | None]: or_key = os.getenv("OPENROUTER_API_KEY") if not or_key: return None, None logger.debug("Auxiliary client: OpenRouter") - return OpenAI(api_key=or_key, base_url=OPENROUTER_BASE_URL, - default_headers=_OR_HEADERS), _OPENROUTER_MODEL + return OpenAI(api_key=or_key, base_url=OPENROUTER_BASE_URL, default_headers=_OR_HEADERS), _OPENROUTER_MODEL -def _try_nous() -> Tuple[Optional[OpenAI], Optional[str]]: +def _try_nous() -> tuple[OpenAI | None, str | None]: nous = _read_nous_auth() if not nous: return None, None @@ -435,7 +443,7 @@ def _try_nous() -> Tuple[Optional[OpenAI], Optional[str]]: ) -def _try_custom_endpoint() -> Tuple[Optional[OpenAI], Optional[str]]: +def _try_custom_endpoint() -> tuple[OpenAI | None, str | None]: custom_base = os.getenv("OPENAI_BASE_URL") custom_key = os.getenv("OPENAI_API_KEY") if not custom_base or not custom_key: @@ -445,7 +453,7 @@ def _try_custom_endpoint() -> Tuple[Optional[OpenAI], Optional[str]]: return OpenAI(api_key=custom_key, base_url=custom_base), model -def _try_codex() -> Tuple[Optional[Any], Optional[str]]: +def _try_codex() -> tuple[Any | None, str | None]: codex_token = _read_codex_access_token() if not codex_token: return None, None @@ -454,7 +462,7 @@ def _try_codex() -> Tuple[Optional[Any], Optional[str]]: return CodexAuxiliaryClient(real_client, _CODEX_AUX_MODEL), _CODEX_AUX_MODEL -def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[str]]: +def _resolve_forced_provider(forced: str) -> tuple[OpenAI | None, str | None]: """Resolve a specific forced provider. Returns (None, None) if creds missing.""" if forced == "openrouter": client, model = _try_openrouter() @@ -488,10 +496,9 @@ def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[st return None, None -def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]: +def _resolve_auto() -> tuple[OpenAI | None, str | None]: """Full auto-detection chain: OpenRouter → Nous → custom → Codex → API-key → None.""" - for try_fn in (_try_openrouter, _try_nous, _try_custom_endpoint, - _try_codex, _resolve_api_key_provider): + for try_fn in (_try_openrouter, _try_nous, _try_custom_endpoint, _try_codex, _resolve_api_key_provider): client, model = try_fn() if client is not None: return client, model @@ -501,7 +508,8 @@ def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]: # ── Public API ────────────────────────────────────────────────────────────── -def get_text_auxiliary_client(task: str = "") -> Tuple[Optional[OpenAI], Optional[str]]: + +def get_text_auxiliary_client(task: str = "") -> tuple[OpenAI | None, str | None]: """Return (client, default_model_slug) for text-only auxiliary tasks. Args: @@ -544,7 +552,7 @@ def get_async_text_auxiliary_client(task: str = ""): return AsyncOpenAI(**async_kwargs), model -def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]: +def get_vision_auxiliary_client() -> tuple[OpenAI | None, str | None]: """Return (client, default_model_slug) for vision/multimodal auxiliary tasks. Checks AUXILIARY_VISION_PROVIDER for a forced provider, otherwise @@ -564,8 +572,7 @@ def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]: # back to the user's custom endpoint. Many local models (Qwen-VL, # LLaVA, Pixtral, etc.) support vision — skipping them entirely # caused silent failures for local-only users. - for try_fn in (_try_openrouter, _try_nous, _try_codex, - _try_custom_endpoint): + for try_fn in (_try_openrouter, _try_nous, _try_codex, _try_custom_endpoint): client, model = try_fn() if client is not None: return client, model @@ -575,7 +582,7 @@ def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]: def get_auxiliary_extra_body() -> dict: """Return extra_body kwargs for auxiliary API calls. - + Includes Nous Portal product tags when the auxiliary client is backed by Nous Portal. Returns empty dict otherwise. """ @@ -584,7 +591,7 @@ def get_auxiliary_extra_body() -> dict: def auxiliary_max_tokens_param(value: int) -> dict: """Return the correct max tokens kwarg for the auxiliary client's provider. - + OpenRouter and local models use 'max_tokens'. Direct OpenAI with newer models (gpt-4o, o-series, gpt-5+) requires 'max_completion_tokens'. The Codex adapter translates max_tokens internally, so we use max_tokens @@ -593,8 +600,6 @@ def auxiliary_max_tokens_param(value: int) -> dict: custom_base = os.getenv("OPENAI_BASE_URL", "") or_key = os.getenv("OPENROUTER_API_KEY") # Only use max_completion_tokens for direct OpenAI custom endpoints - if (not or_key - and _read_nous_auth() is None - and "api.openai.com" in custom_base.lower()): + if not or_key and _read_nous_auth() is None and "api.openai.com" in custom_base.lower(): return {"max_completion_tokens": value} return {"max_tokens": value} diff --git a/agent/context_compressor.py b/agent/context_compressor.py index 01aa2af804..f618ac21f1 100644 --- a/agent/context_compressor.py +++ b/agent/context_compressor.py @@ -7,12 +7,12 @@ protecting head and tail context. import logging import os -from typing import Any, Dict, List, Optional +from typing import Any from agent.auxiliary_client import get_text_auxiliary_client from agent.model_metadata import ( - get_model_context_length, estimate_messages_tokens_rough, + get_model_context_length, ) logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ class ContextCompressor: self.client, default_model = get_text_auxiliary_client("compression") self.summary_model = summary_model_override or default_model - def update_from_response(self, usage: Dict[str, Any]): + def update_from_response(self, usage: dict[str, Any]): """Update tracked token usage from API response.""" self.last_prompt_tokens = usage.get("prompt_tokens", 0) self.last_completion_tokens = usage.get("completion_tokens", 0) @@ -67,12 +67,12 @@ class ContextCompressor: tokens = prompt_tokens if prompt_tokens is not None else self.last_prompt_tokens return tokens >= self.threshold_tokens - def should_compress_preflight(self, messages: List[Dict[str, Any]]) -> bool: + def should_compress_preflight(self, messages: list[dict[str, Any]]) -> bool: """Quick pre-flight check using rough estimate (before API call).""" rough_estimate = estimate_messages_tokens_rough(messages) return rough_estimate >= self.threshold_tokens - def get_status(self) -> Dict[str, Any]: + def get_status(self) -> dict[str, Any]: """Get current compression status for display/logging.""" return { "last_prompt_tokens": self.last_prompt_tokens, @@ -82,7 +82,7 @@ class ContextCompressor: "compression_count": self.compression_count, } - def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]: + def _generate_summary(self, turns_to_summarize: list[dict[str, Any]]) -> str | None: """Generate a concise summary of conversation turns. Tries the auxiliary model first, then falls back to the user's main @@ -140,7 +140,9 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" logging.warning(f"Main model summary also failed: {fallback_err}") # 3. All models failed — return None so the caller drops turns without a summary - logging.warning("Context compression: no model available for summary. Middle turns will be dropped without summary.") + logging.warning( + "Context compression: no model available for summary. Middle turns will be dropped without summary." + ) return None def _call_summary_model(self, client, model: str, prompt: str) -> str: @@ -186,12 +188,14 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" # Don't fallback to the same provider that just failed from hermes_constants import OPENROUTER_BASE_URL + if custom_base.rstrip("/") == OPENROUTER_BASE_URL.rstrip("/"): return None, None model = os.getenv("LLM_MODEL") or os.getenv("OPENAI_MODEL") or self.model try: from openai import OpenAI as _OpenAI + client = _OpenAI(api_key=custom_key, base_url=custom_base) logger.debug("Built fallback auxiliary client: %s via %s", model, custom_base) return client, model @@ -210,7 +214,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" return tc.get("id", "") return getattr(tc, "id", "") or "" - def _sanitize_tool_pairs(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _sanitize_tool_pairs(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: """Fix orphaned tool_call / tool_result pairs after compression. Two failure modes: @@ -243,8 +247,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" orphaned_results = result_call_ids - surviving_call_ids if orphaned_results: messages = [ - m for m in messages - if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results) + m for m in messages if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results) ] if not self.quiet_mode: logger.info("Compression sanitizer: removed %d orphaned tool result(s)", len(orphaned_results)) @@ -252,25 +255,27 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" # 2. Add stub results for assistant tool_calls whose results were dropped missing_results = surviving_call_ids - result_call_ids if missing_results: - patched: List[Dict[str, Any]] = [] + patched: list[dict[str, Any]] = [] for msg in messages: patched.append(msg) if msg.get("role") == "assistant": for tc in msg.get("tool_calls") or []: cid = self._get_tool_call_id(tc) if cid in missing_results: - patched.append({ - "role": "tool", - "content": "[Result from earlier conversation — see context summary above]", - "tool_call_id": cid, - }) + patched.append( + { + "role": "tool", + "content": "[Result from earlier conversation — see context summary above]", + "tool_call_id": cid, + } + ) messages = patched if not self.quiet_mode: logger.info("Compression sanitizer: added %d stub tool result(s)", len(missing_results)) return messages - def _align_boundary_forward(self, messages: List[Dict[str, Any]], idx: int) -> int: + def _align_boundary_forward(self, messages: list[dict[str, Any]], idx: int) -> int: """Push a compress-start boundary forward past any orphan tool results. If ``messages[idx]`` is a tool result, slide forward until we hit a @@ -280,7 +285,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" idx += 1 return idx - def _align_boundary_backward(self, messages: List[Dict[str, Any]], idx: int) -> int: + def _align_boundary_backward(self, messages: list[dict[str, Any]], idx: int) -> int: """Pull a compress-end boundary backward to avoid splitting a tool_call / result group. @@ -298,7 +303,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" idx -= 1 return idx - def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None) -> List[Dict[str, Any]]: + def compress(self, messages: list[dict[str, Any]], current_tokens: int = None) -> list[dict[str, Any]]: """Compress conversation messages by summarizing middle turns. Keeps first N + last N turns, summarizes everything in between. @@ -308,7 +313,9 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" n_messages = len(messages) if n_messages <= self.protect_first_n + self.protect_last_n + 1: if not self.quiet_mode: - print(f"⚠️ Cannot compress: only {n_messages} messages (need > {self.protect_first_n + self.protect_last_n + 1})") + print( + f"⚠️ Cannot compress: only {n_messages} messages (need > {self.protect_first_n + self.protect_last_n + 1})" + ) return messages compress_start = self.protect_first_n @@ -323,14 +330,20 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" return messages turns_to_summarize = messages[compress_start:compress_end] - display_tokens = current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages) + display_tokens = ( + current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages) + ) if not self.quiet_mode: - print(f"\n📦 Context compression triggered ({display_tokens:,} tokens ≥ {self.threshold_tokens:,} threshold)") - print(f" 📊 Model context limit: {self.context_length:,} tokens ({self.threshold_percent*100:.0f}% = {self.threshold_tokens:,})") + print( + f"\n📦 Context compression triggered ({display_tokens:,} tokens ≥ {self.threshold_tokens:,} threshold)" + ) + print( + f" 📊 Model context limit: {self.context_length:,} tokens ({self.threshold_percent * 100:.0f}% = {self.threshold_tokens:,})" + ) if not self.quiet_mode: - print(f" 🗜️ Summarizing turns {compress_start+1}-{compress_end} ({len(turns_to_summarize)} turns)") + print(f" 🗜️ Summarizing turns {compress_start + 1}-{compress_end} ({len(turns_to_summarize)} turns)") summary = self._generate_summary(turns_to_summarize) @@ -338,7 +351,9 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" for i in range(compress_start): msg = messages[i].copy() if i == 0 and msg.get("role") == "system" and self.compression_count == 0: - msg["content"] = (msg.get("content") or "") + "\n\n[Note: Some earlier conversation turns may be summarized to preserve context space.]" + msg["content"] = ( + msg.get("content") or "" + ) + "\n\n[Note: Some earlier conversation turns may be summarized to preserve context space.]" compressed.append(msg) if summary: diff --git a/agent/display.py b/agent/display.py index 17595ce279..864cfb6353 100644 --- a/agent/display.py +++ b/agent/display.py @@ -6,7 +6,6 @@ Used by AIAgent._execute_tool_calls for CLI feedback. import json import os -import random import sys import threading import time @@ -20,19 +19,31 @@ _RESET = "\033[0m" # Tool preview (one-line summary of a tool call's primary argument) # ========================================================================= + def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str: """Build a short preview of a tool call's primary argument for display.""" primary_args = { - "terminal": "command", "web_search": "query", "web_extract": "urls", - "read_file": "path", "write_file": "path", "patch": "path", - "search_files": "pattern", "browser_navigate": "url", - "browser_click": "ref", "browser_type": "text", - "image_generate": "prompt", "text_to_speech": "text", - "vision_analyze": "question", "mixture_of_agents": "user_prompt", - "skill_view": "name", "skills_list": "category", + "terminal": "command", + "web_search": "query", + "web_extract": "urls", + "read_file": "path", + "write_file": "path", + "patch": "path", + "search_files": "pattern", + "browser_navigate": "url", + "browser_click": "ref", + "browser_type": "text", + "image_generate": "prompt", + "text_to_speech": "text", + "vision_analyze": "question", + "mixture_of_agents": "user_prompt", + "skill_view": "name", + "skills_list": "category", "schedule_cronjob": "name", - "execute_code": "code", "delegate_task": "goal", - "clarify": "question", "skill_manage": "name", + "execute_code": "code", + "delegate_task": "goal", + "clarify": "question", + "skill_manage": "name", } if tool_name == "process": @@ -61,18 +72,18 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str: if tool_name == "session_search": query = args.get("query", "") - return f"recall: \"{query[:25]}{'...' if len(query) > 25 else ''}\"" + return f'recall: "{query[:25]}{"..." if len(query) > 25 else ""}"' if tool_name == "memory": action = args.get("action", "") target = args.get("target", "") if action == "add": content = args.get("content", "") - return f"+{target}: \"{content[:25]}{'...' if len(content) > 25 else ''}\"" + return f'+{target}: "{content[:25]}{"..." if len(content) > 25 else ""}"' elif action == "replace": - return f"~{target}: \"{args.get('old_text', '')[:20]}\"" + return f'~{target}: "{args.get("old_text", "")[:20]}"' elif action == "remove": - return f"-{target}: \"{args.get('old_text', '')[:20]}\"" + return f'-{target}: "{args.get("old_text", "")[:20]}"' return action if tool_name == "send_message": @@ -80,7 +91,7 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str: msg = args.get("message", "") if len(msg) > 20: msg = msg[:17] + "..." - return f"to {target}: \"{msg}\"" + return f'to {target}: "{msg}"' if tool_name.startswith("rl_"): rl_previews = { @@ -115,7 +126,7 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str: if not preview: return None if len(preview) > max_len: - preview = preview[:max_len - 3] + "..." + preview = preview[: max_len - 3] + "..." return preview @@ -123,41 +134,74 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str: # KawaiiSpinner # ========================================================================= + class KawaiiSpinner: """Animated spinner with kawaii faces for CLI feedback during tool execution.""" SPINNERS = { - 'dots': ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'], - 'bounce': ['⠁', '⠂', '⠄', '⡀', '⢀', '⠠', '⠐', '⠈'], - 'grow': ['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█', '▇', '▆', '▅', '▄', '▃', '▂'], - 'arrows': ['←', '↖', '↑', '↗', '→', '↘', '↓', '↙'], - 'star': ['✶', '✷', '✸', '✹', '✺', '✹', '✸', '✷'], - 'moon': ['🌑', '🌒', '🌓', '🌔', '🌕', '🌖', '🌗', '🌘'], - 'pulse': ['◜', '◠', '◝', '◞', '◡', '◟'], - 'brain': ['🧠', '💭', '💡', '✨', '💫', '🌟', '💡', '💭'], - 'sparkle': ['⁺', '˚', '*', '✧', '✦', '✧', '*', '˚'], + "dots": ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"], + "bounce": ["⠁", "⠂", "⠄", "⡀", "⢀", "⠠", "⠐", "⠈"], + "grow": ["▁", "▂", "▃", "▄", "▅", "▆", "▇", "█", "▇", "▆", "▅", "▄", "▃", "▂"], + "arrows": ["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"], + "star": ["✶", "✷", "✸", "✹", "✺", "✹", "✸", "✷"], + "moon": ["🌑", "🌒", "🌓", "🌔", "🌕", "🌖", "🌗", "🌘"], + "pulse": ["◜", "◠", "◝", "◞", "◡", "◟"], + "brain": ["🧠", "💭", "💡", "✨", "💫", "🌟", "💡", "💭"], + "sparkle": ["⁺", "˚", "*", "✧", "✦", "✧", "*", "˚"], } KAWAII_WAITING = [ - "(。◕‿◕。)", "(◕‿◕✿)", "٩(◕‿◕。)۶", "(✿◠‿◠)", "( ˘▽˘)っ", - "♪(´ε` )", "(◕ᴗ◕✿)", "ヾ(^∇^)", "(≧◡≦)", "(★ω★)", + "(。◕‿◕。)", + "(◕‿◕✿)", + "٩(◕‿◕。)۶", + "(✿◠‿◠)", + "( ˘▽˘)っ", + "♪(´ε` )", + "(◕ᴗ◕✿)", + "ヾ(^∇^)", + "(≧◡≦)", + "(★ω★)", ] KAWAII_THINKING = [ - "(。•́︿•̀。)", "(◔_◔)", "(¬‿¬)", "( •_•)>⌐■-■", "(⌐■_■)", - "(´・_・`)", "◉_◉", "(°ロ°)", "( ˘⌣˘)♡", "ヽ(>∀<☆)☆", - "٩(๑❛ᴗ❛๑)۶", "(⊙_⊙)", "(¬_¬)", "( ͡° ͜ʖ ͡°)", "ಠ_ಠ", + "(。•́︿•̀。)", + "(◔_◔)", + "(¬‿¬)", + "( •_•)>⌐■-■", + "(⌐■_■)", + "(´・_・`)", + "◉_◉", + "(°ロ°)", + "( ˘⌣˘)♡", + "ヽ(>∀<☆)☆", + "٩(๑❛ᴗ❛๑)۶", + "(⊙_⊙)", + "(¬_¬)", + "( ͡° ͜ʖ ͡°)", + "ಠ_ಠ", ] THINKING_VERBS = [ - "pondering", "contemplating", "musing", "cogitating", "ruminating", - "deliberating", "mulling", "reflecting", "processing", "reasoning", - "analyzing", "computing", "synthesizing", "formulating", "brainstorming", + "pondering", + "contemplating", + "musing", + "cogitating", + "ruminating", + "deliberating", + "mulling", + "reflecting", + "processing", + "reasoning", + "analyzing", + "computing", + "synthesizing", + "formulating", + "brainstorming", ] - def __init__(self, message: str = "", spinner_type: str = 'dots'): + def __init__(self, message: str = "", spinner_type: str = "dots"): self.message = message - self.spinner_frames = self.SPINNERS.get(spinner_type, self.SPINNERS['dots']) + self.spinner_frames = self.SPINNERS.get(spinner_type, self.SPINNERS["dots"]) self.running = False self.thread = None self.frame_idx = 0 @@ -167,7 +211,7 @@ class KawaiiSpinner: # child agents can replace sys.stdout with a black hole. self._out = sys.stdout - def _write(self, text: str, end: str = '\n', flush: bool = False): + def _write(self, text: str, end: str = "\n", flush: bool = False): """Write to the stdout captured at spinner creation time.""" try: self._out.write(text + end) @@ -185,7 +229,7 @@ class KawaiiSpinner: elapsed = time.time() - self.start_time line = f" {frame} {self.message} ({elapsed:.1f}s)" pad = max(self.last_line_len - len(line), 0) - self._write(f"\r{line}{' ' * pad}", end='', flush=True) + self._write(f"\r{line}{' ' * pad}", end="", flush=True) self.last_line_len = len(line) self.frame_idx += 1 time.sleep(0.12) @@ -216,7 +260,7 @@ class KawaiiSpinner: # Clear spinner line with spaces (not \033[K) to avoid garbled escape # codes when prompt_toolkit's patch_stdout is active — same approach # as stop(). Then print text; spinner redraws on next tick. - blanks = ' ' * max(self.last_line_len + 5, 40) + blanks = " " * max(self.last_line_len + 5, 40) self._write(f"\r{blanks}\r {text}", flush=True) def stop(self, final_message: str = None): @@ -225,8 +269,8 @@ class KawaiiSpinner: self.thread.join(timeout=0.5) # Clear the spinner line with spaces instead of \033[K to avoid # garbled escape codes when prompt_toolkit's patch_stdout is active. - blanks = ' ' * max(self.last_line_len + 5, 40) - self._write(f"\r{blanks}\r", end='', flush=True) + blanks = " " * max(self.last_line_len + 5, 40) + self._write(f"\r{blanks}\r", end="", flush=True) if final_message: self._write(f" {final_message}", flush=True) @@ -244,38 +288,110 @@ class KawaiiSpinner: # ========================================================================= KAWAII_SEARCH = [ - "♪(´ε` )", "(。◕‿◕。)", "ヾ(^∇^)", "(◕ᴗ◕✿)", "( ˘▽˘)っ", - "٩(◕‿◕。)۶", "(✿◠‿◠)", "♪~(´ε` )", "(ノ´ヮ`)ノ*:・゚✧", "\(◎o◎)/", + "♪(´ε` )", + "(。◕‿◕。)", + "ヾ(^∇^)", + "(◕ᴗ◕✿)", + "( ˘▽˘)っ", + "٩(◕‿◕。)۶", + "(✿◠‿◠)", + "♪~(´ε` )", + "(ノ´ヮ`)ノ*:・゚✧", + "\(◎o◎)/", ] KAWAII_READ = [ - "φ(゜▽゜*)♪", "( ˘▽˘)っ", "(⌐■_■)", "٩(。•́‿•̀。)۶", "(◕‿◕✿)", - "ヾ(@⌒ー⌒@)ノ", "(✧ω✧)", "♪(๑ᴖ◡ᴖ๑)♪", "(≧◡≦)", "( ´ ▽ ` )ノ", + "φ(゜▽゜*)♪", + "( ˘▽˘)っ", + "(⌐■_■)", + "٩(。•́‿•̀。)۶", + "(◕‿◕✿)", + "ヾ(@⌒ー⌒@)ノ", + "(✧ω✧)", + "♪(๑ᴖ◡ᴖ๑)♪", + "(≧◡≦)", + "( ´ ▽ ` )ノ", ] KAWAII_TERMINAL = [ - "ヽ(>∀<☆)ノ", "(ノ°∀°)ノ", "٩(^ᴗ^)۶", "ヾ(⌐■_■)ノ♪", "(•̀ᴗ•́)و", - "┗(^0^)┓", "(`・ω・´)", "\( ̄▽ ̄)/", "(ง •̀_•́)ง", "ヽ(´▽`)/", + "ヽ(>∀<☆)ノ", + "(ノ°∀°)ノ", + "٩(^ᴗ^)۶", + "ヾ(⌐■_■)ノ♪", + "(•̀ᴗ•́)و", + "┗(^0^)┓", + "(`・ω・´)", + "\( ̄▽ ̄)/", + "(ง •̀_•́)ง", + "ヽ(´▽`)/", ] KAWAII_BROWSER = [ - "(ノ°∀°)ノ", "(☞゚ヮ゚)☞", "( ͡° ͜ʖ ͡°)", "┌( ಠ_ಠ)┘", "(⊙_⊙)?", - "ヾ(•ω•`)o", "( ̄ω ̄)", "( ˇωˇ )", "(ᵔᴥᵔ)", "\(◎o◎)/", + "(ノ°∀°)ノ", + "(☞゚ヮ゚)☞", + "( ͡° ͜ʖ ͡°)", + "┌( ಠ_ಠ)┘", + "(⊙_⊙)?", + "ヾ(•ω•`)o", + "( ̄ω ̄)", + "( ˇωˇ )", + "(ᵔᴥᵔ)", + "\(◎o◎)/", ] KAWAII_CREATE = [ - "✧*。٩(ˊᗜˋ*)و✧", "(ノ◕ヮ◕)ノ*:・゚✧", "ヽ(>∀<☆)ノ", "٩(♡ε♡)۶", "(◕‿◕)♡", - "✿◕ ‿ ◕✿", "(*≧▽≦)", "ヾ(^-^)ノ", "(☆▽☆)", "°˖✧◝(⁰▿⁰)◜✧˖°", + "✧*。٩(ˊᗜˋ*)و✧", + "(ノ◕ヮ◕)ノ*:・゚✧", + "ヽ(>∀<☆)ノ", + "٩(♡ε♡)۶", + "(◕‿◕)♡", + "✿◕ ‿ ◕✿", + "(*≧▽≦)", + "ヾ(^-^)ノ", + "(☆▽☆)", + "°˖✧◝(⁰▿⁰)◜✧˖°", ] KAWAII_SKILL = [ - "ヾ(@⌒ー⌒@)ノ", "(๑˃ᴗ˂)ﻭ", "٩(◕‿◕。)۶", "(✿╹◡╹)", "ヽ(・∀・)ノ", - "(ノ´ヮ`)ノ*:・゚✧", "♪(๑ᴖ◡ᴖ๑)♪", "(◠‿◠)", "٩(ˊᗜˋ*)و", "(^▽^)", - "ヾ(^∇^)", "(★ω★)/", "٩(。•́‿•̀。)۶", "(◕ᴗ◕✿)", "\(◎o◎)/", - "(✧ω✧)", "ヽ(>∀<☆)ノ", "( ˘▽˘)っ", "(≧◡≦) ♡", "ヾ( ̄▽ ̄)", + "ヾ(@⌒ー⌒@)ノ", + "(๑˃ᴗ˂)ﻭ", + "٩(◕‿◕。)۶", + "(✿╹◡╹)", + "ヽ(・∀・)ノ", + "(ノ´ヮ`)ノ*:・゚✧", + "♪(๑ᴖ◡ᴖ๑)♪", + "(◠‿◠)", + "٩(ˊᗜˋ*)و", + "(^▽^)", + "ヾ(^∇^)", + "(★ω★)/", + "٩(。•́‿•̀。)۶", + "(◕ᴗ◕✿)", + "\(◎o◎)/", + "(✧ω✧)", + "ヽ(>∀<☆)ノ", + "( ˘▽˘)っ", + "(≧◡≦) ♡", + "ヾ( ̄▽ ̄)", ] KAWAII_THINK = [ - "(っ°Д°;)っ", "(;′⌒`)", "(・_・ヾ", "( ´_ゝ`)", "( ̄ヘ ̄)", - "(。-`ω´-)", "( ˘︹˘ )", "(¬_¬)", "ヽ(ー_ー )ノ", "(;一_一)", + "(っ°Д°;)っ", + "(;′⌒`)", + "(・_・ヾ", + "( ´_ゝ`)", + "( ̄ヘ ̄)", + "(。-`ω´-)", + "( ˘︹˘ )", + "(¬_¬)", + "ヽ(ー_ー )ノ", + "(;一_一)", ] KAWAII_GENERIC = [ - "♪(´ε` )", "(◕‿◕✿)", "ヾ(^∇^)", "٩(◕‿◕。)۶", "(✿◠‿◠)", - "(ノ´ヮ`)ノ*:・゚✧", "ヽ(>∀<☆)ノ", "(☆▽☆)", "( ˘▽˘)っ", "(≧◡≦)", + "♪(´ε` )", + "(◕‿◕✿)", + "ヾ(^∇^)", + "٩(◕‿◕。)۶", + "(✿◠‿◠)", + "(ノ´ヮ`)ノ*:・゚✧", + "ヽ(>∀<☆)ノ", + "(☆▽☆)", + "( ˘▽˘)っ", + "(≧◡≦)", ] @@ -283,6 +399,7 @@ KAWAII_GENERIC = [ # Cute tool message (completion line that replaces the spinner) # ========================================================================= + def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]: """Inspect a tool result string for signs of failure. @@ -321,7 +438,10 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str] def get_cute_tool_message( - tool_name: str, args: dict, duration: float, result: str | None = None, + tool_name: str, + args: dict, + duration: float, + result: str | None = None, ) -> str: """Generate a formatted tool completion line for CLI quiet mode. @@ -335,11 +455,11 @@ def get_cute_tool_message( def _trunc(s, n=40): s = str(s) - return (s[:n-3] + "...") if len(s) > n else s + return (s[: n - 3] + "...") if len(s) > n else s def _path(p, n=35): p = str(p) - return ("..." + p[-(n-3):]) if len(p) > n else p + return ("..." + p[-(n - 3) :]) if len(p) > n else p def _wrap(line: str) -> str: """Append failure suffix when the tool failed.""" @@ -354,7 +474,7 @@ def get_cute_tool_message( if urls: url = urls[0] if isinstance(urls, list) else str(urls) domain = url.replace("https://", "").replace("http://", "").split("/")[0] - extra = f" +{len(urls)-1}" if len(urls) > 1 else "" + extra = f" +{len(urls) - 1}" if len(urls) > 1 else "" return _wrap(f"┊ 📄 fetch {_trunc(domain, 35)}{extra} {dur}") return _wrap(f"┊ 📄 fetch pages {dur}") if tool_name == "web_crawl": @@ -366,8 +486,15 @@ def get_cute_tool_message( if tool_name == "process": action = args.get("action", "?") sid = args.get("session_id", "")[:12] - labels = {"list": "ls processes", "poll": f"poll {sid}", "log": f"log {sid}", - "wait": f"wait {sid}", "kill": f"kill {sid}", "write": f"write {sid}", "submit": f"submit {sid}"} + labels = { + "list": "ls processes", + "poll": f"poll {sid}", + "log": f"log {sid}", + "wait": f"wait {sid}", + "kill": f"kill {sid}", + "write": f"write {sid}", + "submit": f"submit {sid}", + } return _wrap(f"┊ ⚙️ proc {labels.get(action, f'{action} {sid}')} {dur}") if tool_name == "read_file": return _wrap(f"┊ 📖 read {_path(args.get('path', ''))} {dur}") @@ -390,7 +517,7 @@ def get_cute_tool_message( if tool_name == "browser_click": return _wrap(f"┊ 👆 click {args.get('ref', '?')} {dur}") if tool_name == "browser_type": - return _wrap(f"┊ ⌨️ type \"{_trunc(args.get('text', ''), 30)}\" {dur}") + return _wrap(f'┊ ⌨️ type "{_trunc(args.get("text", ""), 30)}" {dur}') if tool_name == "browser_scroll": d = args.get("direction", "down") arrow = {"down": "↓", "up": "↑", "right": "→", "left": "←"}.get(d, "↓") @@ -415,16 +542,16 @@ def get_cute_tool_message( else: return _wrap(f"┊ 📋 plan {len(todos_arg)} task(s) {dur}") if tool_name == "session_search": - return _wrap(f"┊ 🔍 recall \"{_trunc(args.get('query', ''), 35)}\" {dur}") + return _wrap(f'┊ 🔍 recall "{_trunc(args.get("query", ""), 35)}" {dur}') if tool_name == "memory": action = args.get("action", "?") target = args.get("target", "") if action == "add": - return _wrap(f"┊ 🧠 memory +{target}: \"{_trunc(args.get('content', ''), 30)}\" {dur}") + return _wrap(f'┊ 🧠 memory +{target}: "{_trunc(args.get("content", ""), 30)}" {dur}') elif action == "replace": - return _wrap(f"┊ 🧠 memory ~{target}: \"{_trunc(args.get('old_text', ''), 20)}\" {dur}") + return _wrap(f'┊ 🧠 memory ~{target}: "{_trunc(args.get("old_text", ""), 20)}" {dur}') elif action == "remove": - return _wrap(f"┊ 🧠 memory -{target}: \"{_trunc(args.get('old_text', ''), 20)}\" {dur}") + return _wrap(f'┊ 🧠 memory -{target}: "{_trunc(args.get("old_text", ""), 20)}" {dur}') return _wrap(f"┊ 🧠 memory {action} {dur}") if tool_name == "skills_list": return _wrap(f"┊ 📚 skills list {args.get('category', 'all')} {dur}") @@ -439,7 +566,7 @@ def get_cute_tool_message( if tool_name == "mixture_of_agents": return _wrap(f"┊ 🧠 reason {_trunc(args.get('user_prompt', ''), 30)} {dur}") if tool_name == "send_message": - return _wrap(f"┊ 📨 send {args.get('target', '?')}: \"{_trunc(args.get('message', ''), 25)}\" {dur}") + return _wrap(f'┊ 📨 send {args.get("target", "?")}: "{_trunc(args.get("message", ""), 25)}" {dur}') if tool_name == "schedule_cronjob": return _wrap(f"┊ ⏰ schedule {_trunc(args.get('name', args.get('prompt', 'task')), 30)} {dur}") if tool_name == "list_cronjobs": @@ -448,11 +575,16 @@ def get_cute_tool_message( return _wrap(f"┊ ⏰ remove job {args.get('job_id', '?')} {dur}") if tool_name.startswith("rl_"): rl = { - "rl_list_environments": "list envs", "rl_select_environment": f"select {args.get('name', '')}", - "rl_get_current_config": "get config", "rl_edit_config": f"set {args.get('field', '?')}", - "rl_start_training": "start training", "rl_check_status": f"status {args.get('run_id', '?')[:12]}", - "rl_stop_training": f"stop {args.get('run_id', '?')[:12]}", "rl_get_results": f"results {args.get('run_id', '?')[:12]}", - "rl_list_runs": "list runs", "rl_test_inference": "test inference", + "rl_list_environments": "list envs", + "rl_select_environment": f"select {args.get('name', '')}", + "rl_get_current_config": "get config", + "rl_edit_config": f"set {args.get('field', '?')}", + "rl_start_training": "start training", + "rl_check_status": f"status {args.get('run_id', '?')[:12]}", + "rl_stop_training": f"stop {args.get('run_id', '?')[:12]}", + "rl_get_results": f"results {args.get('run_id', '?')[:12]}", + "rl_list_runs": "list runs", + "rl_test_inference": "test inference", } return _wrap(f"┊ 🧪 rl {rl.get(tool_name, tool_name.replace('rl_', ''))} {dur}") if tool_name == "execute_code": diff --git a/agent/insights.py b/agent/insights.py index df3b9e85c8..ed4d07412c 100644 --- a/agent/insights.py +++ b/agent/insights.py @@ -20,7 +20,7 @@ import json import time from collections import Counter, defaultdict from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any # ========================================================================= # Model pricing (USD per million tokens) — approximate as of early 2026 @@ -81,7 +81,7 @@ def _has_known_pricing(model_name: str) -> bool: return _get_pricing(model_name) is not _DEFAULT_PRICING -def _get_pricing(model_name: str) -> Dict[str, float]: +def _get_pricing(model_name: str) -> dict[str, float]: """Look up pricing for a model. Uses fuzzy matching on model name. Returns _DEFAULT_PRICING (zero cost) for unknown/custom models — @@ -150,7 +150,7 @@ def _format_duration(seconds: float) -> str: return f"{days:.1f}d" -def _bar_chart(values: List[int], max_width: int = 20) -> List[str]: +def _bar_chart(values: list[int], max_width: int = 20) -> list[str]: """Create simple horizontal bar chart strings from values.""" peak = max(values) if values else 1 if peak == 0: @@ -176,7 +176,7 @@ class InsightsEngine: self.db = db self._conn = db._conn - def generate(self, days: int = 30, source: str = None) -> Dict[str, Any]: + def generate(self, days: int = 30, source: str = None) -> dict[str, Any]: """ Generate a complete insights report. @@ -233,10 +233,11 @@ class InsightsEngine: # ========================================================================= # Columns we actually need (skip system_prompt, model_config blobs) - _SESSION_COLS = ("id, source, model, started_at, ended_at, " - "message_count, tool_call_count, input_tokens, output_tokens") + _SESSION_COLS = ( + "id, source, model, started_at, ended_at, message_count, tool_call_count, input_tokens, output_tokens" + ) - def _get_sessions(self, cutoff: float, source: str = None) -> List[Dict]: + def _get_sessions(self, cutoff: float, source: str = None) -> list[dict]: """Fetch sessions within the time window.""" if source: cursor = self._conn.execute( @@ -254,7 +255,7 @@ class InsightsEngine: ) return [dict(row) for row in cursor.fetchall()] - def _get_tool_usage(self, cutoff: float, source: str = None) -> List[Dict]: + def _get_tool_usage(self, cutoff: float, source: str = None) -> list[dict]: """Get tool call counts from messages. Uses two sources: @@ -341,12 +342,9 @@ class InsightsEngine: tool_counts = merged # Convert to the expected format - return [ - {"tool_name": name, "count": count} - for name, count in tool_counts.most_common() - ] + return [{"tool_name": name, "count": count} for name, count in tool_counts.most_common()] - def _get_message_stats(self, cutoff: float, source: str = None) -> Dict: + def _get_message_stats(self, cutoff: float, source: str = None) -> dict: """Get aggregate message statistics.""" if source: cursor = self._conn.execute( @@ -373,16 +371,22 @@ class InsightsEngine: (cutoff,), ) row = cursor.fetchone() - return dict(row) if row else { - "total_messages": 0, "user_messages": 0, - "assistant_messages": 0, "tool_messages": 0, - } + return ( + dict(row) + if row + else { + "total_messages": 0, + "user_messages": 0, + "assistant_messages": 0, + "tool_messages": 0, + } + ) # ========================================================================= # Computation # ========================================================================= - def _compute_overview(self, sessions: List[Dict], message_stats: Dict) -> Dict: + def _compute_overview(self, sessions: list[dict], message_stats: dict) -> dict: """Compute high-level overview statistics.""" total_input = sum(s.get("input_tokens") or 0 for s in sessions) total_output = sum(s.get("output_tokens") or 0 for s in sessions) @@ -442,12 +446,18 @@ class InsightsEngine: "models_without_pricing": sorted(models_without_pricing), } - def _compute_model_breakdown(self, sessions: List[Dict]) -> List[Dict]: + def _compute_model_breakdown(self, sessions: list[dict]) -> list[dict]: """Break down usage by model.""" - model_data = defaultdict(lambda: { - "sessions": 0, "input_tokens": 0, "output_tokens": 0, - "total_tokens": 0, "tool_calls": 0, "cost": 0.0, - }) + model_data = defaultdict( + lambda: { + "sessions": 0, + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "tool_calls": 0, + "cost": 0.0, + } + ) for s in sessions: model = s.get("model") or "unknown" @@ -464,20 +474,23 @@ class InsightsEngine: d["cost"] += _estimate_cost(model, inp, out) d["has_pricing"] = _has_known_pricing(model) - result = [ - {"model": model, **data} - for model, data in model_data.items() - ] + result = [{"model": model, **data} for model, data in model_data.items()] # Sort by tokens first, fall back to session count when tokens are 0 result.sort(key=lambda x: (x["total_tokens"], x["sessions"]), reverse=True) return result - def _compute_platform_breakdown(self, sessions: List[Dict]) -> List[Dict]: + def _compute_platform_breakdown(self, sessions: list[dict]) -> list[dict]: """Break down usage by platform/source.""" - platform_data = defaultdict(lambda: { - "sessions": 0, "messages": 0, "input_tokens": 0, - "output_tokens": 0, "total_tokens": 0, "tool_calls": 0, - }) + platform_data = defaultdict( + lambda: { + "sessions": 0, + "messages": 0, + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "tool_calls": 0, + } + ) for s in sessions: source = s.get("source") or "unknown" @@ -491,27 +504,26 @@ class InsightsEngine: d["total_tokens"] += inp + out d["tool_calls"] += s.get("tool_call_count") or 0 - result = [ - {"platform": platform, **data} - for platform, data in platform_data.items() - ] + result = [{"platform": platform, **data} for platform, data in platform_data.items()] result.sort(key=lambda x: x["sessions"], reverse=True) return result - def _compute_tool_breakdown(self, tool_usage: List[Dict]) -> List[Dict]: + def _compute_tool_breakdown(self, tool_usage: list[dict]) -> list[dict]: """Process tool usage data into a ranked list with percentages.""" total_calls = sum(t["count"] for t in tool_usage) if tool_usage else 0 result = [] for t in tool_usage: pct = (t["count"] / total_calls * 100) if total_calls else 0 - result.append({ - "tool": t["tool_name"], - "count": t["count"], - "percentage": pct, - }) + result.append( + { + "tool": t["tool_name"], + "count": t["count"], + "percentage": pct, + } + ) return result - def _compute_activity_patterns(self, sessions: List[Dict]) -> Dict: + def _compute_activity_patterns(self, sessions: list[dict]) -> dict: """Analyze activity patterns by day of week and hour.""" day_counts = Counter() # 0=Monday ... 6=Sunday hour_counts = Counter() @@ -527,15 +539,9 @@ class InsightsEngine: daily_counts[dt.strftime("%Y-%m-%d")] += 1 day_names = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] - day_breakdown = [ - {"day": day_names[i], "count": day_counts.get(i, 0)} - for i in range(7) - ] + day_breakdown = [{"day": day_names[i], "count": day_counts.get(i, 0)} for i in range(7)] - hour_breakdown = [ - {"hour": i, "count": hour_counts.get(i, 0)} - for i in range(24) - ] + hour_breakdown = [{"hour": i, "count": hour_counts.get(i, 0)} for i in range(24)] # Busiest day and hour busiest_day = max(day_breakdown, key=lambda x: x["count"]) if day_breakdown else None @@ -569,37 +575,40 @@ class InsightsEngine: "max_streak": max_streak, } - def _compute_top_sessions(self, sessions: List[Dict]) -> List[Dict]: + def _compute_top_sessions(self, sessions: list[dict]) -> list[dict]: """Find notable sessions (longest, most messages, most tokens).""" top = [] # Longest by duration - sessions_with_duration = [ - s for s in sessions - if s.get("started_at") and s.get("ended_at") - ] + sessions_with_duration = [s for s in sessions if s.get("started_at") and s.get("ended_at")] if sessions_with_duration: longest = max( sessions_with_duration, - key=lambda s: (s["ended_at"] - s["started_at"]), + key=lambda s: s["ended_at"] - s["started_at"], ) dur = longest["ended_at"] - longest["started_at"] - top.append({ - "label": "Longest session", - "session_id": longest["id"][:16], - "value": _format_duration(dur), - "date": datetime.fromtimestamp(longest["started_at"]).strftime("%b %d"), - }) + top.append( + { + "label": "Longest session", + "session_id": longest["id"][:16], + "value": _format_duration(dur), + "date": datetime.fromtimestamp(longest["started_at"]).strftime("%b %d"), + } + ) # Most messages most_msgs = max(sessions, key=lambda s: s.get("message_count") or 0) if (most_msgs.get("message_count") or 0) > 0: - top.append({ - "label": "Most messages", - "session_id": most_msgs["id"][:16], - "value": f"{most_msgs['message_count']} msgs", - "date": datetime.fromtimestamp(most_msgs["started_at"]).strftime("%b %d") if most_msgs.get("started_at") else "?", - }) + top.append( + { + "label": "Most messages", + "session_id": most_msgs["id"][:16], + "value": f"{most_msgs['message_count']} msgs", + "date": datetime.fromtimestamp(most_msgs["started_at"]).strftime("%b %d") + if most_msgs.get("started_at") + else "?", + } + ) # Most tokens most_tokens = max( @@ -608,22 +617,30 @@ class InsightsEngine: ) token_total = (most_tokens.get("input_tokens") or 0) + (most_tokens.get("output_tokens") or 0) if token_total > 0: - top.append({ - "label": "Most tokens", - "session_id": most_tokens["id"][:16], - "value": f"{token_total:,} tokens", - "date": datetime.fromtimestamp(most_tokens["started_at"]).strftime("%b %d") if most_tokens.get("started_at") else "?", - }) + top.append( + { + "label": "Most tokens", + "session_id": most_tokens["id"][:16], + "value": f"{token_total:,} tokens", + "date": datetime.fromtimestamp(most_tokens["started_at"]).strftime("%b %d") + if most_tokens.get("started_at") + else "?", + } + ) # Most tool calls most_tools = max(sessions, key=lambda s: s.get("tool_call_count") or 0) if (most_tools.get("tool_call_count") or 0) > 0: - top.append({ - "label": "Most tool calls", - "session_id": most_tools["id"][:16], - "value": f"{most_tools['tool_call_count']} calls", - "date": datetime.fromtimestamp(most_tools["started_at"]).strftime("%b %d") if most_tools.get("started_at") else "?", - }) + top.append( + { + "label": "Most tool calls", + "session_id": most_tools["id"][:16], + "value": f"{most_tools['tool_call_count']} calls", + "date": datetime.fromtimestamp(most_tools["started_at"]).strftime("%b %d") + if most_tools.get("started_at") + else "?", + } + ) return top @@ -631,7 +648,7 @@ class InsightsEngine: # Formatting # ========================================================================= - def format_terminal(self, report: Dict) -> str: + def format_terminal(self, report: dict) -> str: """Format the insights report for terminal display (CLI).""" if report.get("empty"): days = report.get("days", 30) @@ -669,13 +686,17 @@ class InsightsEngine: lines.append(" " + "─" * 56) lines.append(f" Sessions: {o['total_sessions']:<12} Messages: {o['total_messages']:,}") lines.append(f" Tool calls: {o['total_tool_calls']:<12,} User messages: {o['user_messages']:,}") - lines.append(f" Input tokens: {o['total_input_tokens']:<12,} Output tokens: {o['total_output_tokens']:,}") + lines.append( + f" Input tokens: {o['total_input_tokens']:<12,} Output tokens: {o['total_output_tokens']:,}" + ) cost_str = f"${o['estimated_cost']:.2f}" if o.get("models_without_pricing"): cost_str += " *" lines.append(f" Total tokens: {o['total_tokens']:<12,} Est. cost: {cost_str}") if o["total_hours"] > 0: - lines.append(f" Active time: ~{_format_duration(o['total_hours'] * 3600):<11} Avg session: ~{_format_duration(o['avg_session_duration'])}") + lines.append( + f" Active time: ~{_format_duration(o['total_hours'] * 3600):<11} Avg session: ~{_format_duration(o['avg_session_duration'])}" + ) lines.append(f" Avg msgs/session: {o['avg_messages_per_session']:.1f}") lines.append("") @@ -692,7 +713,7 @@ class InsightsEngine: cost_cell = " N/A" lines.append(f" {model_name:<30} {m['sessions']:>8} {m['total_tokens']:>12,} {cost_cell}") if o.get("models_without_pricing"): - lines.append(f" * Cost N/A for custom/self-hosted models") + lines.append(" * Cost N/A for custom/self-hosted models") lines.append("") # Platform breakdown @@ -758,7 +779,7 @@ class InsightsEngine: return "\n".join(lines) - def format_gateway(self, report: Dict) -> str: + def format_gateway(self, report: dict) -> str: """Format the insights report for gateway/messaging (shorter).""" if report.get("empty"): days = report.get("days", 30) @@ -771,14 +792,20 @@ class InsightsEngine: lines.append(f"📊 **Hermes Insights** — Last {days} days\n") # Overview - lines.append(f"**Sessions:** {o['total_sessions']} | **Messages:** {o['total_messages']:,} | **Tool calls:** {o['total_tool_calls']:,}") - lines.append(f"**Tokens:** {o['total_tokens']:,} (in: {o['total_input_tokens']:,} / out: {o['total_output_tokens']:,})") + lines.append( + f"**Sessions:** {o['total_sessions']} | **Messages:** {o['total_messages']:,} | **Tool calls:** {o['total_tool_calls']:,}" + ) + lines.append( + f"**Tokens:** {o['total_tokens']:,} (in: {o['total_input_tokens']:,} / out: {o['total_output_tokens']:,})" + ) cost_note = "" if o.get("models_without_pricing"): cost_note = " _(excludes custom/self-hosted models)_" lines.append(f"**Est. cost:** ${o['estimated_cost']:.2f}{cost_note}") if o["total_hours"] > 0: - lines.append(f"**Active time:** ~{_format_duration(o['total_hours'] * 3600)} | **Avg session:** ~{_format_duration(o['avg_session_duration'])}") + lines.append( + f"**Active time:** ~{_format_duration(o['total_hours'] * 3600)} | **Avg session:** ~{_format_duration(o['avg_session_duration'])}" + ) lines.append("") # Models (top 5) @@ -786,7 +813,9 @@ class InsightsEngine: lines.append("**🤖 Models:**") for m in report["models"][:5]: cost_str = f"${m['cost']:.2f}" if m.get("has_pricing") else "N/A" - lines.append(f" {m['model'][:25]} — {m['sessions']} sessions, {m['total_tokens']:,} tokens, {cost_str}") + lines.append( + f" {m['model'][:25]} — {m['sessions']} sessions, {m['total_tokens']:,} tokens, {cost_str}" + ) lines.append("") # Platforms (if multi-platform) @@ -809,9 +838,13 @@ class InsightsEngine: hr = act["busiest_hour"]["hour"] ampm = "AM" if hr < 12 else "PM" display_hr = hr % 12 or 12 - lines.append(f"**📅 Busiest:** {act['busiest_day']['day']}s ({act['busiest_day']['count']} sessions), {display_hr}{ampm} ({act['busiest_hour']['count']} sessions)") + lines.append( + f"**📅 Busiest:** {act['busiest_day']['day']}s ({act['busiest_day']['count']} sessions), {display_hr}{ampm} ({act['busiest_hour']['count']} sessions)" + ) if act.get("active_days"): - lines.append(f"**Active days:** {act['active_days']}", ) + lines.append( + f"**Active days:** {act['active_days']}", + ) if act.get("max_streak", 0) > 1: lines.append(f"**Best streak:** {act['max_streak']} consecutive days") diff --git a/agent/model_metadata.py b/agent/model_metadata.py index 3b2ab9d0f1..5140b3b6ee 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -9,7 +9,7 @@ import os import re import time from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any import requests import yaml @@ -18,7 +18,7 @@ from hermes_constants import OPENROUTER_MODELS_URL logger = logging.getLogger(__name__) -_model_metadata_cache: Dict[str, Dict[str, Any]] = {} +_model_metadata_cache: dict[str, dict[str, Any]] = {} _model_metadata_cache_time: float = 0 _MODEL_CACHE_TTL = 3600 @@ -63,7 +63,7 @@ DEFAULT_CONTEXT_LENGTHS = { } -def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]: +def fetch_model_metadata(force_refresh: bool = False) -> dict[str, dict[str, Any]]: """Fetch model metadata from OpenRouter (cached for 1 hour).""" global _model_metadata_cache, _model_metadata_cache_time @@ -104,7 +104,7 @@ def _get_context_cache_path() -> Path: return hermes_home / "context_length_cache.yaml" -def _load_context_cache() -> Dict[str, int]: +def _load_context_cache() -> dict[str, int]: """Load the model+provider → context_length cache from disk.""" path = _get_context_cache_path() if not path.exists(): @@ -139,14 +139,14 @@ def save_context_length(model: str, base_url: str, length: int) -> None: logger.debug("Failed to save context length cache: %s", e) -def get_cached_context_length(model: str, base_url: str) -> Optional[int]: +def get_cached_context_length(model: str, base_url: str) -> int | None: """Look up a previously discovered context length for model+provider.""" key = f"{model}@{base_url}" cache = _load_context_cache() return cache.get(key) -def get_next_probe_tier(current_length: int) -> Optional[int]: +def get_next_probe_tier(current_length: int) -> int | None: """Return the next lower probe tier, or None if already at minimum.""" for tier in CONTEXT_PROBE_TIERS: if tier < current_length: @@ -154,7 +154,7 @@ def get_next_probe_tier(current_length: int) -> Optional[int]: return None -def parse_context_limit_from_error(error_msg: str) -> Optional[int]: +def parse_context_limit_from_error(error_msg: str) -> int | None: """Try to extract the actual context limit from an API error message. Many providers include the limit in their error text, e.g.: @@ -166,11 +166,11 @@ def parse_context_limit_from_error(error_msg: str) -> Optional[int]: error_lower = error_msg.lower() # Pattern: look for numbers near context-related keywords patterns = [ - r'(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})', - r'context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})', - r'(\d{4,})\s*(?:token)?\s*(?:context|limit)', - r'>\s*(\d{4,})\s*(?:max|limit|token)', # "250000 tokens > 200000 maximum" - r'(\d{4,})\s*(?:max(?:imum)?)\b', # "200000 maximum" + r"(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})", + r"context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})", + r"(\d{4,})\s*(?:token)?\s*(?:context|limit)", + r">\s*(\d{4,})\s*(?:max|limit|token)", # "250000 tokens > 200000 maximum" + r"(\d{4,})\s*(?:max(?:imum)?)\b", # "200000 maximum" ] for pattern in patterns: match = re.search(pattern, error_lower) @@ -218,7 +218,7 @@ def estimate_tokens_rough(text: str) -> int: return len(text) // 4 -def estimate_messages_tokens_rough(messages: List[Dict[str, Any]]) -> int: +def estimate_messages_tokens_rough(messages: list[dict[str, Any]]) -> int: """Rough token estimate for a message list (pre-flight only).""" total_chars = sum(len(str(msg)) for msg in messages) return total_chars // 4 diff --git a/agent/prompt_builder.py b/agent/prompt_builder.py index 0582d63d36..6a8dc0ab8e 100644 --- a/agent/prompt_builder.py +++ b/agent/prompt_builder.py @@ -8,7 +8,6 @@ import logging import os import re from pathlib import Path -from typing import Optional logger = logging.getLogger(__name__) @@ -18,21 +17,29 @@ logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- _CONTEXT_THREAT_PATTERNS = [ - (r'ignore\s+(previous|all|above|prior)\s+instructions', "prompt_injection"), - (r'do\s+not\s+tell\s+the\s+user', "deception_hide"), - (r'system\s+prompt\s+override', "sys_prompt_override"), - (r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"), - (r'act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)', "bypass_restrictions"), - (r'', "html_comment_injection"), + (r"ignore\s+(previous|all|above|prior)\s+instructions", "prompt_injection"), + (r"do\s+not\s+tell\s+the\s+user", "deception_hide"), + (r"system\s+prompt\s+override", "sys_prompt_override"), + (r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"), + (r"act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)", "bypass_restrictions"), + (r"", "html_comment_injection"), (r'<\s*div\s+style\s*=\s*["\'].*display\s*:\s*none', "hidden_div"), - (r'translate\s+.*\s+into\s+.*\s+and\s+(execute|run|eval)', "translate_execute"), - (r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"), - (r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)', "read_secrets"), + (r"translate\s+.*\s+into\s+.*\s+and\s+(execute|run|eval)", "translate_execute"), + (r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"), + (r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)", "read_secrets"), ] _CONTEXT_INVISIBLE_CHARS = { - '\u200b', '\u200c', '\u200d', '\u2060', '\ufeff', - '\u202a', '\u202b', '\u202c', '\u202d', '\u202e', + "\u200b", + "\u200c", + "\u200d", + "\u2060", + "\ufeff", + "\u202a", + "\u202b", + "\u202c", + "\u202d", + "\u202e", } @@ -52,10 +59,13 @@ def _scan_context_content(content: str, filename: str) -> str: if findings: logger.warning("Context file %s blocked: %s", filename, ", ".join(findings)) - return f"[BLOCKED: {filename} contained potential prompt injection ({', '.join(findings)}). Content not loaded.]" + return ( + f"[BLOCKED: {filename} contained potential prompt injection ({', '.join(findings)}). Content not loaded.]" + ) return content + # ========================================================================= # Constants # ========================================================================= @@ -131,10 +141,7 @@ PLATFORM_HINTS = { "files arrive as downloadable documents. You can also include image " "URLs in markdown format ![alt](url) and they will be sent as photos." ), - "cli": ( - "You are a CLI AI Agent. Try not to use markdown but simple text " - "renderable inside a terminal." - ), + "cli": ("You are a CLI AI Agent. Try not to use markdown but simple text renderable inside a terminal."), } CONTEXT_FILE_MAX_CHARS = 20_000 @@ -146,18 +153,20 @@ CONTEXT_TRUNCATE_TAIL_RATIO = 0.2 # Skills index # ========================================================================= + def _read_skill_description(skill_file: Path, max_chars: int = 60) -> str: """Read the description from a SKILL.md frontmatter, capped at max_chars.""" try: raw = skill_file.read_text(encoding="utf-8")[:2000] match = re.search( r"^---\s*\n.*?description:\s*(.+?)\s*\n.*?^---", - raw, re.MULTILINE | re.DOTALL, + raw, + re.MULTILINE | re.DOTALL, ) if match: desc = match.group(1).strip().strip("'\"") if len(desc) > max_chars: - desc = desc[:max_chars - 3] + "..." + desc = desc[: max_chars - 3] + "..." return desc except Exception: pass @@ -172,6 +181,7 @@ def _skill_is_platform_compatible(skill_file: Path) -> bool: """ try: from tools.skills_tool import _parse_frontmatter, skill_matches_platform + raw = skill_file.read_text(encoding="utf-8")[:2000] frontmatter, _ = _parse_frontmatter(raw) return skill_matches_platform(frontmatter) @@ -260,8 +270,7 @@ def build_skills_system_prompt() -> str: "load it with skill_view(name) and follow its instructions. " "If a skill has issues, fix it with skill_manage(action='patch').\n" "\n" - "\n" - + "\n".join(index_lines) + "\n" + "\n" + "\n".join(index_lines) + "\n" "\n" "\n" "If none match, proceed normally without loading a skill." @@ -272,6 +281,7 @@ def build_skills_system_prompt() -> str: # Context files (SOUL.md, AGENTS.md, .cursorrules) # ========================================================================= + def _truncate_content(content: str, filename: str, max_chars: int = CONTEXT_FILE_MAX_CHARS) -> str: """Head/tail truncation with a marker in the middle.""" if len(content) <= max_chars: @@ -284,7 +294,7 @@ def _truncate_content(content: str, filename: str, max_chars: int = CONTEXT_FILE return head + marker + tail -def build_context_files_prompt(cwd: Optional[str] = None) -> str: +def build_context_files_prompt(cwd: str | None = None) -> str: """Discover and load context files for the system prompt. Discovery: AGENTS.md (recursive), .cursorrules / .cursor/rules/*.mdc, @@ -307,7 +317,9 @@ def build_context_files_prompt(cwd: Optional[str] = None) -> str: if top_level_agents: agents_files = [] for root, dirs, files in os.walk(cwd_path): - dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ('node_modules', '__pycache__', 'venv', '.venv')] + dirs[:] = [ + d for d in dirs if not d.startswith(".") and d not in ("node_modules", "__pycache__", "venv", ".venv") + ] for f in files: if f.lower() == "agents.md": agents_files.append(Path(root) / f) @@ -384,4 +396,7 @@ def build_context_files_prompt(cwd: Optional[str] = None) -> str: if not sections: return "" - return "# Project Context\n\nThe following project context files have been loaded and should be followed:\n\n" + "\n".join(sections) + return ( + "# Project Context\n\nThe following project context files have been loaded and should be followed:\n\n" + + "\n".join(sections) + ) diff --git a/agent/prompt_caching.py b/agent/prompt_caching.py index aa80b2ddfa..842da407b3 100644 --- a/agent/prompt_caching.py +++ b/agent/prompt_caching.py @@ -9,7 +9,7 @@ Pure functions -- no class state, no AIAgent dependency. """ import copy -from typing import Any, Dict, List +from typing import Any def _apply_cache_marker(msg: dict, cache_marker: dict) -> None: @@ -36,9 +36,9 @@ def _apply_cache_marker(msg: dict, cache_marker: dict) -> None: def apply_anthropic_cache_control( - api_messages: List[Dict[str, Any]], + api_messages: list[dict[str, Any]], cache_ttl: str = "5m", -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """Apply system_and_3 caching strategy to messages for Anthropic models. Places up to 4 cache_control breakpoints: system prompt + last 3 non-system messages. diff --git a/agent/redact.py b/agent/redact.py index 02700c8327..afa93a0503 100644 --- a/agent/redact.py +++ b/agent/redact.py @@ -10,34 +10,33 @@ the first 6 and last 4 characters for debuggability. import logging import os import re -from typing import Optional logger = logging.getLogger(__name__) # Known API key prefixes -- match the prefix + contiguous token chars _PREFIX_PATTERNS = [ - r"sk-[A-Za-z0-9_-]{10,}", # OpenAI / OpenRouter / Anthropic (sk-ant-*) - r"ghp_[A-Za-z0-9]{10,}", # GitHub PAT (classic) - r"github_pat_[A-Za-z0-9_]{10,}", # GitHub PAT (fine-grained) - r"xox[baprs]-[A-Za-z0-9-]{10,}", # Slack tokens - r"AIza[A-Za-z0-9_-]{30,}", # Google API keys - r"pplx-[A-Za-z0-9]{10,}", # Perplexity - r"fal_[A-Za-z0-9_-]{10,}", # Fal.ai - r"fc-[A-Za-z0-9]{10,}", # Firecrawl - r"bb_live_[A-Za-z0-9_-]{10,}", # BrowserBase - r"gAAAA[A-Za-z0-9_=-]{20,}", # Codex encrypted tokens - r"AKIA[A-Z0-9]{16}", # AWS Access Key ID - r"sk_live_[A-Za-z0-9]{10,}", # Stripe secret key (live) - r"sk_test_[A-Za-z0-9]{10,}", # Stripe secret key (test) - r"rk_live_[A-Za-z0-9]{10,}", # Stripe restricted key - r"SG\.[A-Za-z0-9_-]{10,}", # SendGrid API key - r"hf_[A-Za-z0-9]{10,}", # HuggingFace token - r"r8_[A-Za-z0-9]{10,}", # Replicate API token - r"npm_[A-Za-z0-9]{10,}", # npm access token - r"pypi-[A-Za-z0-9_-]{10,}", # PyPI API token - r"dop_v1_[A-Za-z0-9]{10,}", # DigitalOcean PAT - r"doo_v1_[A-Za-z0-9]{10,}", # DigitalOcean OAuth - r"am_[A-Za-z0-9_-]{10,}", # AgentMail API key + r"sk-[A-Za-z0-9_-]{10,}", # OpenAI / OpenRouter / Anthropic (sk-ant-*) + r"ghp_[A-Za-z0-9]{10,}", # GitHub PAT (classic) + r"github_pat_[A-Za-z0-9_]{10,}", # GitHub PAT (fine-grained) + r"xox[baprs]-[A-Za-z0-9-]{10,}", # Slack tokens + r"AIza[A-Za-z0-9_-]{30,}", # Google API keys + r"pplx-[A-Za-z0-9]{10,}", # Perplexity + r"fal_[A-Za-z0-9_-]{10,}", # Fal.ai + r"fc-[A-Za-z0-9]{10,}", # Firecrawl + r"bb_live_[A-Za-z0-9_-]{10,}", # BrowserBase + r"gAAAA[A-Za-z0-9_=-]{20,}", # Codex encrypted tokens + r"AKIA[A-Z0-9]{16}", # AWS Access Key ID + r"sk_live_[A-Za-z0-9]{10,}", # Stripe secret key (live) + r"sk_test_[A-Za-z0-9]{10,}", # Stripe secret key (test) + r"rk_live_[A-Za-z0-9]{10,}", # Stripe restricted key + r"SG\.[A-Za-z0-9_-]{10,}", # SendGrid API key + r"hf_[A-Za-z0-9]{10,}", # HuggingFace token + r"r8_[A-Za-z0-9]{10,}", # Replicate API token + r"npm_[A-Za-z0-9]{10,}", # npm access token + r"pypi-[A-Za-z0-9_-]{10,}", # PyPI API token + r"dop_v1_[A-Za-z0-9]{10,}", # DigitalOcean PAT + r"doo_v1_[A-Za-z0-9]{10,}", # DigitalOcean OAuth + r"am_[A-Za-z0-9_-]{10,}", # AgentMail API key ] # ENV assignment patterns: KEY=value where KEY contains a secret-like name @@ -66,9 +65,7 @@ _TELEGRAM_RE = re.compile( ) # Private key blocks: -----BEGIN RSA PRIVATE KEY----- ... -----END RSA PRIVATE KEY----- -_PRIVATE_KEY_RE = re.compile( - r"-----BEGIN[A-Z ]*PRIVATE KEY-----[\s\S]*?-----END[A-Z ]*PRIVATE KEY-----" -) +_PRIVATE_KEY_RE = re.compile(r"-----BEGIN[A-Z ]*PRIVATE KEY-----[\s\S]*?-----END[A-Z ]*PRIVATE KEY-----") # Database connection strings: protocol://user:PASSWORD@host # Catches postgres, mysql, mongodb, redis, amqp URLs and redacts the password @@ -82,9 +79,7 @@ _DB_CONNSTR_RE = re.compile( _SIGNAL_PHONE_RE = re.compile(r"(\+[1-9]\d{6,14})(?![A-Za-z0-9])") # Compile known prefix patterns into one alternation -_PREFIX_RE = re.compile( - r"(? str: @@ -112,12 +107,14 @@ def redact_sensitive_text(text: str) -> str: def _redact_env(m): name, quote, value = m.group(1), m.group(2), m.group(3) return f"{name}={quote}{_mask_token(value)}{quote}" + text = _ENV_ASSIGN_RE.sub(_redact_env, text) # JSON fields: "apiKey": "value" def _redact_json(m): key, value = m.group(1), m.group(2) return f'{key}: "{_mask_token(value)}"' + text = _JSON_FIELD_RE.sub(_redact_json, text) # Authorization headers @@ -131,6 +128,7 @@ def redact_sensitive_text(text: str) -> str: prefix = m.group(1) or "" digits = m.group(2) return f"{prefix}{digits}:***" + text = _TELEGRAM_RE.sub(_redact_telegram, text) # Private key blocks @@ -145,6 +143,7 @@ def redact_sensitive_text(text: str) -> str: if len(phone) <= 8: return phone[:2] + "****" + phone[-2:] return phone[:4] + "****" + phone[-4:] + text = _SIGNAL_PHONE_RE.sub(_redact_phone, text) return text @@ -153,7 +152,7 @@ def redact_sensitive_text(text: str) -> str: class RedactingFormatter(logging.Formatter): """Log formatter that redacts secrets from all log messages.""" - def __init__(self, fmt=None, datefmt=None, style='%', **kwargs): + def __init__(self, fmt=None, datefmt=None, style="%", **kwargs): super().__init__(fmt, datefmt, style, **kwargs) def format(self, record: logging.LogRecord) -> str: diff --git a/agent/skill_commands.py b/agent/skill_commands.py index 4466ba35ca..03cb5a9534 100644 --- a/agent/skill_commands.py +++ b/agent/skill_commands.py @@ -6,14 +6,14 @@ can invoke skills via /skill-name commands. import logging from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any logger = logging.getLogger(__name__) -_skill_commands: Dict[str, Dict[str, Any]] = {} +_skill_commands: dict[str, dict[str, Any]] = {} -def scan_skill_commands() -> Dict[str, Dict[str, Any]]: +def scan_skill_commands() -> dict[str, dict[str, Any]]: """Scan ~/.hermes/skills/ and return a mapping of /command -> skill info. Returns: @@ -23,26 +23,27 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]: _skill_commands = {} try: from tools.skills_tool import SKILLS_DIR, _parse_frontmatter, skill_matches_platform + if not SKILLS_DIR.exists(): return _skill_commands for skill_md in SKILLS_DIR.rglob("SKILL.md"): - if any(part in ('.git', '.github', '.hub') for part in skill_md.parts): + if any(part in (".git", ".github", ".hub") for part in skill_md.parts): continue try: - content = skill_md.read_text(encoding='utf-8') + content = skill_md.read_text(encoding="utf-8") frontmatter, body = _parse_frontmatter(content) # Skip skills incompatible with the current OS platform if not skill_matches_platform(frontmatter): continue - name = frontmatter.get('name', skill_md.parent.name) - description = frontmatter.get('description', '') + name = frontmatter.get("name", skill_md.parent.name) + description = frontmatter.get("description", "") if not description: - for line in body.strip().split('\n'): + for line in body.strip().split("\n"): line = line.strip() - if line and not line.startswith('#'): + if line and not line.startswith("#"): description = line[:80] break - cmd_name = name.lower().replace(' ', '-').replace('_', '-') + cmd_name = name.lower().replace(" ", "-").replace("_", "-") _skill_commands[f"/{cmd_name}"] = { "name": name, "description": description or f"Invoke the {name} skill", @@ -56,14 +57,14 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]: return _skill_commands -def get_skill_commands() -> Dict[str, Dict[str, Any]]: +def get_skill_commands() -> dict[str, dict[str, Any]]: """Return the current skill commands mapping (scan first if empty).""" if not _skill_commands: scan_skill_commands() return _skill_commands -def build_skill_invocation_message(cmd_key: str, user_instruction: str = "") -> Optional[str]: +def build_skill_invocation_message(cmd_key: str, user_instruction: str = "") -> str | None: """Build the user message content for a skill slash command invocation. Args: @@ -83,7 +84,7 @@ def build_skill_invocation_message(cmd_key: str, user_instruction: str = "") -> skill_name = skill_info["name"] try: - content = skill_md_path.read_text(encoding='utf-8') + content = skill_md_path.read_text(encoding="utf-8") except Exception: return f"[Failed to load skill: {skill_name}]" @@ -111,6 +112,8 @@ def build_skill_invocation_message(cmd_key: str, user_instruction: str = "") -> if user_instruction: parts.append("") - parts.append(f"The user has provided the following instruction alongside the skill invocation: {user_instruction}") + parts.append( + f"The user has provided the following instruction alongside the skill invocation: {user_instruction}" + ) return "\n".join(parts) diff --git a/agent/trajectory.py b/agent/trajectory.py index 90696eb8a3..2632c0f63b 100644 --- a/agent/trajectory.py +++ b/agent/trajectory.py @@ -8,7 +8,7 @@ the file-write logic live here. import json import logging from datetime import datetime -from typing import Any, Dict, List +from typing import Any logger = logging.getLogger(__name__) @@ -27,8 +27,7 @@ def has_incomplete_scratchpad(content: str) -> bool: return "" in content and "" not in content -def save_trajectory(trajectory: List[Dict[str, Any]], model: str, - completed: bool, filename: str = None): +def save_trajectory(trajectory: list[dict[str, Any]], model: str, completed: bool, filename: str = None): """Append a trajectory entry to a JSONL file. Args: diff --git a/batch_runner.py b/batch_runner.py index a4c402ffdc..4b60943e37 100644 --- a/batch_runner.py +++ b/batch_runner.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + """ Batch Agent Runner @@ -12,10 +14,10 @@ across multiple prompts from a dataset. It includes: Usage: python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run - + # Resume an interrupted run python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume - + # Use a specific toolset distribution python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen """ @@ -24,23 +26,19 @@ import json import logging import os import time -from pathlib import Path -from typing import List, Dict, Any, Optional, Tuple -from datetime import datetime -from multiprocessing import Pool, Lock import traceback -from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn -from rich.console import Console +from datetime import datetime +from multiprocessing import Lock, Pool +from pathlib import Path +from typing import Any + import fire +from rich.console import Console +from rich.progress import BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TextColumn, TimeRemainingColumn -from run_agent import AIAgent -from toolset_distributions import ( - list_distributions, - sample_toolsets_from_distribution, - validate_distribution -) from model_tools import TOOL_TO_TOOLSET_MAP - +from run_agent import AIAgent +from toolset_distributions import list_distributions, sample_toolsets_from_distribution, validate_distribution # Global configuration for worker processes _WORKER_CONFIG = {} @@ -52,112 +50,108 @@ _WORKER_CONFIG = {} ALL_POSSIBLE_TOOLS = set(TOOL_TO_TOOLSET_MAP.keys()) # Default stats for tools that weren't used -DEFAULT_TOOL_STATS = {'count': 0, 'success': 0, 'failure': 0} +DEFAULT_TOOL_STATS = {"count": 0, "success": 0, "failure": 0} -def _normalize_tool_stats(tool_stats: Dict[str, Dict[str, int]]) -> Dict[str, Dict[str, int]]: +def _normalize_tool_stats(tool_stats: dict[str, dict[str, int]]) -> dict[str, dict[str, int]]: """ Normalize tool_stats to include all possible tools with consistent schema. - + This ensures HuggingFace datasets can load the JSONL without schema mismatch errors. Tools that weren't used get zero counts. - + Args: tool_stats (Dict): Raw tool statistics from extraction - + Returns: Dict: Normalized tool statistics with all tools present """ normalized = {} - + # Add all possible tools with defaults for tool in ALL_POSSIBLE_TOOLS: if tool in tool_stats: normalized[tool] = tool_stats[tool].copy() else: normalized[tool] = DEFAULT_TOOL_STATS.copy() - + # Also include any unexpected tools (in case new tools are added) for tool, stats in tool_stats.items(): if tool not in normalized: normalized[tool] = stats.copy() - + return normalized -def _normalize_tool_error_counts(tool_error_counts: Dict[str, int]) -> Dict[str, int]: +def _normalize_tool_error_counts(tool_error_counts: dict[str, int]) -> dict[str, int]: """ Normalize tool_error_counts to include all possible tools. - + Args: tool_error_counts (Dict): Raw error counts mapping - + Returns: Dict: Normalized error counts with all tools present """ normalized = {} - + # Add all possible tools with zero defaults for tool in ALL_POSSIBLE_TOOLS: normalized[tool] = tool_error_counts.get(tool, 0) - + # Also include any unexpected tools for tool, count in tool_error_counts.items(): if tool not in normalized: normalized[tool] = count - + return normalized -def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]: +def _extract_tool_stats(messages: list[dict[str, Any]]) -> dict[str, dict[str, int]]: """ Extract tool usage statistics from message history. - + Args: messages (List[Dict]): Message history - + Returns: Dict: Tool statistics with counts and success/failure rates """ tool_stats = {} - + # Track tool calls and their results tool_calls_map = {} # Map tool_call_id to tool name - + for msg in messages: # Track tool calls from assistant messages if msg["role"] == "assistant" and "tool_calls" in msg and msg["tool_calls"]: for tool_call in msg["tool_calls"]: tool_name = tool_call["function"]["name"] tool_call_id = tool_call["id"] - + # Initialize stats for this tool if not exists if tool_name not in tool_stats: - tool_stats[tool_name] = { - "count": 0, - "success": 0, - "failure": 0 - } - + tool_stats[tool_name] = {"count": 0, "success": 0, "failure": 0} + tool_stats[tool_name]["count"] += 1 tool_calls_map[tool_call_id] = tool_name - + # Track tool responses elif msg["role"] == "tool": tool_call_id = msg.get("tool_call_id", "") content = msg.get("content", "") - + # Determine if tool call was successful is_success = True try: # Try to parse as JSON and check for actual error values content_json = json.loads(content) if isinstance(content, str) else content - + if isinstance(content_json, dict): # Check if error field exists AND has a non-null value if "error" in content_json and content_json["error"] is not None: is_success = False - + # Special handling for terminal tool responses # Terminal wraps its response in a "content" field if "content" in content_json and isinstance(content_json["content"], dict): @@ -166,20 +160,17 @@ def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, i # Note: non-zero exit codes are not failures - the model can self-correct if inner_content.get("error") is not None: is_success = False - + # Check for "success": false pattern used by some tools if content_json.get("success") is False: is_success = False - + except (json.JSONDecodeError, ValueError, TypeError): # If not JSON, check if content is empty or explicitly states an error # Note: We avoid simple substring matching to prevent false positives - if not content: + if not content or content.strip().lower().startswith("error:"): is_success = False - # Only mark as failure if it explicitly starts with "Error:" or "ERROR:" - elif content.strip().lower().startswith("error:"): - is_success = False - + # Update success/failure count if tool_call_id in tool_calls_map: tool_name = tool_calls_map[tool_call_id] @@ -187,38 +178,38 @@ def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, i tool_stats[tool_name]["success"] += 1 else: tool_stats[tool_name]["failure"] += 1 - + return tool_stats -def _extract_reasoning_stats(messages: List[Dict[str, Any]]) -> Dict[str, int]: +def _extract_reasoning_stats(messages: list[dict[str, Any]]) -> dict[str, int]: """ Count how many assistant turns have reasoning vs no reasoning. - + Checks for in content or a non-empty 'reasoning' field (native thinking tokens). Returns counts for tracking reasoning coverage. - + Args: messages: Message history - + Returns: Dict with 'total_assistant_turns', 'turns_with_reasoning', 'turns_without_reasoning' """ total = 0 with_reasoning = 0 - + for msg in messages: if msg.get("role") != "assistant": continue total += 1 - + content = msg.get("content", "") or "" has_scratchpad = "" in content has_native_reasoning = bool(msg.get("reasoning", "").strip()) if msg.get("reasoning") else False - + if has_scratchpad or has_native_reasoning: with_reasoning += 1 - + return { "total_assistant_turns": total, "turns_with_reasoning": with_reasoning, @@ -228,26 +219,23 @@ def _extract_reasoning_stats(messages: List[Dict[str, Any]]) -> Dict[str, int]: def _process_single_prompt( - prompt_index: int, - prompt_data: Dict[str, Any], - batch_num: int, - config: Dict[str, Any] -) -> Dict[str, Any]: + prompt_index: int, prompt_data: dict[str, Any], batch_num: int, config: dict[str, Any] +) -> dict[str, Any]: """ Process a single prompt with the agent. - + Args: prompt_index (int): Index of prompt in dataset prompt_data (Dict): Prompt data containing 'prompt' field and optional 'image' field batch_num (int): Batch number config (Dict): Configuration dict with agent parameters - + Returns: Dict: Result containing trajectory, stats, and metadata """ prompt = prompt_data["prompt"] task_id = f"task_{prompt_index}" - + # Per-prompt container image override: if the dataset row has an 'image' field, # register it for this task's sandbox. Works with Docker, Modal, Singularity, and Daytona. container_image = prompt_data.get("image") or prompt_data.get("docker_image") @@ -258,17 +246,21 @@ def _process_single_prompt( env_type = os.getenv("TERMINAL_ENV", "local") if env_type == "docker": import subprocess as _sp + try: probe = _sp.run( ["docker", "image", "inspect", container_image], - capture_output=True, timeout=10, + capture_output=True, + timeout=10, ) if probe.returncode != 0: if config.get("verbose"): print(f" Prompt {prompt_index}: Pulling docker image {container_image}...", flush=True) pull = _sp.run( ["docker", "pull", container_image], - capture_output=True, text=True, timeout=600, + capture_output=True, + text=True, + timeout=600, ) if pull.returncode != 0: return { @@ -287,6 +279,7 @@ def _process_single_prompt( print(f" Prompt {prompt_index}: Docker image check failed: {img_err}", flush=True) from tools.terminal_tool import register_task_env_overrides + overrides = { "docker_image": container_image, "modal_image": container_image, @@ -298,14 +291,14 @@ def _process_single_prompt( register_task_env_overrides(task_id, overrides) if config.get("verbose"): print(f" Prompt {prompt_index}: Using container image {container_image}") - + try: # Sample toolsets from distribution for this prompt selected_toolsets = sample_toolsets_from_distribution(config["distribution"]) - + if config.get("verbose"): print(f" Prompt {prompt_index}: Using toolsets {selected_toolsets}") - + # Initialize agent with sampled toolsets and log prefix for identification log_prefix = f"[B{batch_num}:P{prompt_index}]" agent = AIAgent( @@ -332,20 +325,16 @@ def _process_single_prompt( # Run the agent with task_id to ensure each task gets its own isolated VM result = agent.run_conversation(prompt, task_id=task_id) - + # Extract tool usage statistics tool_stats = _extract_tool_stats(result["messages"]) - + # Extract reasoning coverage stats reasoning_stats = _extract_reasoning_stats(result["messages"]) - + # Convert to trajectory format (using existing method) - trajectory = agent._convert_to_trajectory_format( - result["messages"], - prompt, - result["completed"] - ) - + trajectory = agent._convert_to_trajectory_format(result["messages"], prompt, result["completed"]) + return { "success": True, "prompt_index": prompt_index, @@ -356,18 +345,14 @@ def _process_single_prompt( "partial": result.get("partial", False), "api_calls": result["api_calls"], "toolsets_used": selected_toolsets, - "metadata": { - "batch_num": batch_num, - "timestamp": datetime.now().isoformat(), - "model": config["model"] - } + "metadata": {"batch_num": batch_num, "timestamp": datetime.now().isoformat(), "model": config["model"]}, } - + except Exception as e: print(f"❌ Error processing prompt {prompt_index}: {e}") if config.get("verbose"): traceback.print_exc() - + return { "success": False, "prompt_index": prompt_index, @@ -375,37 +360,31 @@ def _process_single_prompt( "trajectory": None, "tool_stats": {}, "toolsets_used": [], - "metadata": { - "batch_num": batch_num, - "timestamp": datetime.now().isoformat() - } + "metadata": {"batch_num": batch_num, "timestamp": datetime.now().isoformat()}, } -def _process_batch_worker(args: Tuple) -> Dict[str, Any]: +def _process_batch_worker(args: tuple) -> dict[str, Any]: """ Worker function to process a single batch of prompts. - + Args: args (Tuple): (batch_num, batch_data, output_dir, completed_prompts, config) - + Returns: Dict: Batch results with statistics """ batch_num, batch_data, output_dir, completed_prompts_set, config = args - + output_dir = Path(output_dir) print(f"\n🔄 Batch {batch_num}: Starting ({len(batch_data)} prompts)") - + # Output file for this batch batch_output_file = output_dir / f"batch_{batch_num}.jsonl" - + # Filter out already completed prompts - prompts_to_process = [ - (idx, data) for idx, data in batch_data - if idx not in completed_prompts_set - ] - + prompts_to_process = [(idx, data) for idx, data in batch_data if idx not in completed_prompts_set] + if not prompts_to_process: print(f"✅ Batch {batch_num}: Already completed (skipping)") return { @@ -413,27 +392,24 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]: "processed": 0, "skipped": len(batch_data), "tool_stats": {}, - "completed_prompts": [] + "completed_prompts": [], } - - print(f" Processing {len(prompts_to_process)} prompts (skipping {len(batch_data) - len(prompts_to_process)} already completed)") - + + print( + f" Processing {len(prompts_to_process)} prompts (skipping {len(batch_data) - len(prompts_to_process)} already completed)" + ) + # Initialize aggregated stats for this batch batch_tool_stats = {} batch_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0} completed_in_batch = [] discarded_no_reasoning = 0 - + # Process each prompt sequentially in this batch for prompt_index, prompt_data in prompts_to_process: # Process the prompt - result = _process_single_prompt( - prompt_index, - prompt_data, - batch_num, - config - ) - + result = _process_single_prompt(prompt_index, prompt_data, batch_num, config) + # Save trajectory if successful if result["success"] and result["trajectory"]: # Discard samples with zero reasoning across all turns @@ -442,18 +418,15 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]: print(f" 🚫 Prompt {prompt_index} discarded (no reasoning in any turn)") discarded_no_reasoning += 1 continue - + # Get and normalize tool stats for consistent schema across all entries raw_tool_stats = result.get("tool_stats", {}) tool_stats = _normalize_tool_stats(raw_tool_stats) - + # Create normalized tool_error_counts mapping tool names to their failure counts - raw_error_counts = { - tool_name: stats.get("failure", 0) - for tool_name, stats in raw_tool_stats.items() - } + raw_error_counts = {tool_name: stats.get("failure", 0) for tool_name, stats in raw_tool_stats.items()} tool_error_counts = _normalize_tool_error_counts(raw_error_counts) - + trajectory_entry = { "prompt_index": prompt_index, "conversations": result["trajectory"], @@ -463,30 +436,26 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]: "api_calls": result["api_calls"], "toolsets_used": result["toolsets_used"], "tool_stats": tool_stats, # Full stats: {tool: {count, success, failure}} - normalized - "tool_error_counts": tool_error_counts # Simple: {tool: failure_count} - normalized + "tool_error_counts": tool_error_counts, # Simple: {tool: failure_count} - normalized } - + # Append to batch output file - with open(batch_output_file, 'a', encoding='utf-8') as f: + with open(batch_output_file, "a", encoding="utf-8") as f: f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n") - + # Aggregate tool statistics for tool_name, stats in result.get("tool_stats", {}).items(): if tool_name not in batch_tool_stats: - batch_tool_stats[tool_name] = { - "count": 0, - "success": 0, - "failure": 0 - } - + batch_tool_stats[tool_name] = {"count": 0, "success": 0, "failure": 0} + batch_tool_stats[tool_name]["count"] += stats["count"] batch_tool_stats[tool_name]["success"] += stats["success"] batch_tool_stats[tool_name]["failure"] += stats["failure"] - + # Aggregate reasoning stats for key in batch_reasoning_stats: batch_reasoning_stats[key] += result.get("reasoning_stats", {}).get(key, 0) - + # Only mark as completed if successfully saved (failed prompts can be retried on resume) if result["success"] and result["trajectory"]: completed_in_batch.append(prompt_index) @@ -494,9 +463,9 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]: print(f" {status} Prompt {prompt_index} completed") else: print(f" ❌ Prompt {prompt_index} failed (will retry on resume)") - + print(f"✅ Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)") - + return { "batch_num": batch_num, "processed": len(prompts_to_process), @@ -504,7 +473,7 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]: "tool_stats": batch_tool_stats, "reasoning_stats": batch_reasoning_stats, "discarded_no_reasoning": discarded_no_reasoning, - "completed_prompts": completed_in_batch + "completed_prompts": completed_in_batch, } @@ -512,7 +481,7 @@ class BatchRunner: """ Manages batch processing of agent prompts with checkpointing and statistics. """ - + def __init__( self, dataset_file: str, @@ -527,13 +496,13 @@ class BatchRunner: verbose: bool = False, ephemeral_system_prompt: str = None, log_prefix_chars: int = 100, - providers_allowed: List[str] = None, - providers_ignored: List[str] = None, - providers_order: List[str] = None, + providers_allowed: list[str] = None, + providers_ignored: list[str] = None, + providers_order: list[str] = None, provider_sort: str = None, max_tokens: int = None, - reasoning_config: Dict[str, Any] = None, - prefill_messages: List[Dict[str, Any]] = None, + reasoning_config: dict[str, Any] = None, + prefill_messages: list[dict[str, Any]] = None, max_samples: int = None, ): """ @@ -581,32 +550,32 @@ class BatchRunner: self.reasoning_config = reasoning_config self.prefill_messages = prefill_messages self.max_samples = max_samples - + # Validate distribution if not validate_distribution(distribution): raise ValueError(f"Unknown distribution: {distribution}. Available: {list(list_distributions().keys())}") - + # Setup output directory self.output_dir = Path("data") / run_name self.output_dir.mkdir(parents=True, exist_ok=True) - + # Checkpoint file self.checkpoint_file = self.output_dir / "checkpoint.json" - + # Statistics file self.stats_file = self.output_dir / "statistics.json" - + # Load dataset (and optionally truncate to max_samples) self.dataset = self._load_dataset() if self.max_samples and self.max_samples < len(self.dataset): full_count = len(self.dataset) - self.dataset = self.dataset[:self.max_samples] + self.dataset = self.dataset[: self.max_samples] print(f"✂️ Truncated dataset from {full_count} to {self.max_samples} samples (--max_samples)") - + # Create batches self.batches = self._create_batches() - - print(f"📊 Batch Runner Initialized") + + print("📊 Batch Runner Initialized") print(f" Dataset: {self.dataset_file} ({len(self.dataset)} prompts)") print(f" Batch size: {self.batch_size}") print(f" Total batches: {len(self.batches)}") @@ -615,86 +584,80 @@ class BatchRunner: print(f" Output directory: {self.output_dir}") print(f" Workers: {self.num_workers}") if self.ephemeral_system_prompt: - prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt + prompt_preview = ( + self.ephemeral_system_prompt[:60] + "..." + if len(self.ephemeral_system_prompt) > 60 + else self.ephemeral_system_prompt + ) print(f" 🔒 Ephemeral system prompt: '{prompt_preview}'") - - def _load_dataset(self) -> List[Dict[str, Any]]: + + def _load_dataset(self) -> list[dict[str, Any]]: """ Load dataset from JSONL file. - + Returns: List[Dict]: List of dataset entries """ if not self.dataset_file.exists(): raise FileNotFoundError(f"Dataset file not found: {self.dataset_file}") - + dataset = [] - with open(self.dataset_file, 'r', encoding='utf-8') as f: + with open(self.dataset_file, encoding="utf-8") as f: for line_num, line in enumerate(f, 1): line = line.strip() if not line: continue - + try: entry = json.loads(line) - if 'prompt' not in entry: + if "prompt" not in entry: print(f"⚠️ Warning: Line {line_num} missing 'prompt' field, skipping") continue dataset.append(entry) except json.JSONDecodeError as e: print(f"⚠️ Warning: Invalid JSON on line {line_num}: {e}") continue - + if not dataset: raise ValueError(f"No valid entries found in dataset file: {self.dataset_file}") - + return dataset - - def _create_batches(self) -> List[List[Tuple[int, Dict[str, Any]]]]: + + def _create_batches(self) -> list[list[tuple[int, dict[str, Any]]]]: """ Split dataset into batches with indices. - + Returns: List of batches, where each batch is a list of (index, entry) tuples """ batches = [] for i in range(0, len(self.dataset), self.batch_size): - batch = [(idx, entry) for idx, entry in enumerate(self.dataset[i:i + self.batch_size], start=i)] + batch = [(idx, entry) for idx, entry in enumerate(self.dataset[i : i + self.batch_size], start=i)] batches.append(batch) - + return batches - - def _load_checkpoint(self) -> Dict[str, Any]: + + def _load_checkpoint(self) -> dict[str, Any]: """ Load checkpoint data if it exists. - + Returns: Dict: Checkpoint data with completed prompt indices """ if not self.checkpoint_file.exists(): - return { - "run_name": self.run_name, - "completed_prompts": [], - "batch_stats": {}, - "last_updated": None - } - + return {"run_name": self.run_name, "completed_prompts": [], "batch_stats": {}, "last_updated": None} + try: - with open(self.checkpoint_file, 'r', encoding='utf-8') as f: + with open(self.checkpoint_file, encoding="utf-8") as f: return json.load(f) except Exception as e: print(f"⚠️ Warning: Failed to load checkpoint: {e}") - return { - "run_name": self.run_name, - "completed_prompts": [], - "batch_stats": {}, - "last_updated": None - } - - def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock: Optional[Lock] = None): + return {"run_name": self.run_name, "completed_prompts": [], "batch_stats": {}, "last_updated": None} + + def _save_checkpoint(self, checkpoint_data: dict[str, Any], lock: Lock | None = None): """ Save checkpoint data. - + Args: checkpoint_data (Dict): Checkpoint data to save lock (Lock): Optional lock for thread-safe access @@ -702,41 +665,42 @@ class BatchRunner: checkpoint_data["last_updated"] = datetime.now().isoformat() from utils import atomic_json_write + if lock: with lock: atomic_json_write(self.checkpoint_file, checkpoint_data) else: atomic_json_write(self.checkpoint_file, checkpoint_data) - + def _scan_completed_prompts_by_content(self) -> set: """ Scan all batch files and extract completed prompts by their actual content. - + This provides a more robust resume mechanism that matches on prompt text rather than indices, allowing recovery even if indices don't match. - + Returns: set: Set of prompt texts that have been successfully processed """ completed_prompts = set() batch_files = sorted(self.output_dir.glob("batch_*.jsonl")) - + if not batch_files: return completed_prompts - + print(f"📂 Scanning {len(batch_files)} batch files for completed prompts...") - + for batch_file in batch_files: try: - with open(batch_file, 'r', encoding='utf-8') as f: + with open(batch_file, encoding="utf-8") as f: for line in f: try: entry = json.loads(line.strip()) - + # Skip failed entries - we want to retry these if entry.get("failed", False): continue - + # Extract the human/user prompt from conversations conversations = entry.get("conversations", []) for msg in conversations: @@ -749,26 +713,26 @@ class BatchRunner: continue except Exception as e: print(f" ⚠️ Warning: Error reading {batch_file.name}: {e}") - + return completed_prompts - - def _filter_dataset_by_completed(self, completed_prompts: set) -> Tuple[List[Dict], List[int]]: + + def _filter_dataset_by_completed(self, completed_prompts: set) -> tuple[list[dict], list[int]]: """ Filter the dataset to exclude prompts that have already been completed. - + Args: completed_prompts: Set of prompt texts that have been completed - + Returns: Tuple of (filtered_dataset, skipped_indices) """ filtered_dataset = [] skipped_indices = [] - + for idx, entry in enumerate(self.dataset): # Extract prompt from the dataset entry prompt_text = entry.get("prompt", "").strip() - + # Also check conversations format if not prompt_text: conversations = entry.get("conversations", []) @@ -777,60 +741,60 @@ class BatchRunner: if role in ("user", "human"): prompt_text = (msg.get("content") or msg.get("value", "")).strip() break - + if prompt_text in completed_prompts: skipped_indices.append(idx) else: # Keep original index for tracking filtered_dataset.append((idx, entry)) - + return filtered_dataset, skipped_indices - + def run(self, resume: bool = False): """ Run the batch processing pipeline. - + Args: resume (bool): Whether to resume from checkpoint """ print("\n" + "=" * 70) print("🚀 Starting Batch Processing") print("=" * 70) - + # Smart resume: scan batch files by content to find completed prompts completed_prompt_texts = set() if resume: completed_prompt_texts = self._scan_completed_prompts_by_content() if completed_prompt_texts: print(f" Found {len(completed_prompt_texts)} already-completed prompts by content matching") - + # Filter dataset to only include unprocessed prompts if resume and completed_prompt_texts: filtered_entries, skipped_indices = self._filter_dataset_by_completed(completed_prompt_texts) - + if not filtered_entries: print("\n✅ All prompts have already been processed!") return - + # Recreate batches from filtered entries (keeping original indices for tracking) batches_to_process = [] for i in range(0, len(filtered_entries), self.batch_size): - batch = filtered_entries[i:i + self.batch_size] + batch = filtered_entries[i : i + self.batch_size] batches_to_process.append(batch) - + self.batches = batches_to_process - + # Print prominent resume summary print("\n" + "=" * 70) print("📊 RESUME SUMMARY") print("=" * 70) print(f" Original dataset size: {len(self.dataset):,} prompts") print(f" Already completed: {len(skipped_indices):,} prompts") - print(f" ─────────────────────────────────────────") + print(" ─────────────────────────────────────────") print(f" 🎯 RESUMING WITH: {len(filtered_entries):,} prompts") print(f" New batches created: {len(batches_to_process)}") print("=" * 70 + "\n") - + # Load existing checkpoint (so resume doesn't clobber prior progress) checkpoint_data = self._load_checkpoint() if checkpoint_data.get("run_name") != self.run_name: @@ -838,9 +802,9 @@ class BatchRunner: "run_name": self.run_name, "completed_prompts": [], "batch_stats": {}, - "last_updated": None + "last_updated": None, } - + # Prepare configuration for workers config = { "distribution": self.distribution, @@ -859,17 +823,17 @@ class BatchRunner: "reasoning_config": self.reasoning_config, "prefill_messages": self.prefill_messages, } - + # For backward compatibility, still track by index (but this is secondary to content matching) completed_prompts_set = set(checkpoint_data.get("completed_prompts", [])) - + # Aggregate statistics across all batches total_tool_stats = {} - + start_time = time.time() - + print(f"\n🔧 Initializing {self.num_workers} worker processes...") - + # Checkpoint writes happen in the parent process; keep a lock for safety. checkpoint_lock = Lock() @@ -882,14 +846,14 @@ class BatchRunner: batch_data, str(self.output_dir), # Convert Path to string for pickling completed_prompts_set, - config + config, ) for batch_num, batch_data in enumerate(self.batches) ] - + print(f"✅ Created {len(tasks)} batch tasks") - print(f"🚀 Starting parallel batch processing...\n") - + print("🚀 Starting parallel batch processing...\n") + # Use rich Progress for better visual tracking with persistent bottom bar # redirect_stdout/stderr lets rich manage all output so progress bar stays clean results = [] @@ -908,12 +872,12 @@ class BatchRunner: redirect_stderr=False, ) as progress: task = progress.add_task("Processing", total=len(tasks)) - + # Temporarily suppress DEBUG logging to avoid bar interference root_logger = logging.getLogger() original_level = root_logger.level root_logger.setLevel(logging.WARNING) - + try: for result in pool.imap_unordered(_process_batch_worker, tasks): results.append(result) @@ -921,18 +885,18 @@ class BatchRunner: # Incremental checkpoint update (so resume works after crash) try: - batch_num = result.get('batch_num') - completed = result.get('completed_prompts', []) or [] + batch_num = result.get("batch_num") + completed = result.get("completed_prompts", []) or [] completed_prompts_set.update(completed) if isinstance(batch_num, int): - checkpoint_data.setdefault('batch_stats', {})[str(batch_num)] = { - 'processed': result.get('processed', 0), - 'skipped': result.get('skipped', 0), - 'discarded_no_reasoning': result.get('discarded_no_reasoning', 0), + checkpoint_data.setdefault("batch_stats", {})[str(batch_num)] = { + "processed": result.get("processed", 0), + "skipped": result.get("skipped", 0), + "discarded_no_reasoning": result.get("discarded_no_reasoning", 0), } - checkpoint_data['completed_prompts'] = sorted(completed_prompts_set) + checkpoint_data["completed_prompts"] = sorted(completed_prompts_set) self._save_checkpoint(checkpoint_data, lock=checkpoint_lock) except Exception as ckpt_err: # Don't fail the run if checkpoint write fails @@ -942,39 +906,35 @@ class BatchRunner: raise finally: root_logger.setLevel(original_level) - + # Aggregate all batch statistics and update checkpoint all_completed_prompts = list(completed_prompts_set) total_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0} - + for batch_result in results: # Add newly completed prompts all_completed_prompts.extend(batch_result.get("completed_prompts", [])) - + # Aggregate tool stats for tool_name, stats in batch_result.get("tool_stats", {}).items(): if tool_name not in total_tool_stats: - total_tool_stats[tool_name] = { - "count": 0, - "success": 0, - "failure": 0 - } - + total_tool_stats[tool_name] = {"count": 0, "success": 0, "failure": 0} + total_tool_stats[tool_name]["count"] += stats["count"] total_tool_stats[tool_name]["success"] += stats["success"] total_tool_stats[tool_name]["failure"] += stats["failure"] - + # Aggregate reasoning stats for key in total_reasoning_stats: total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0) - + # Save final checkpoint (best-effort; incremental writes already happened) try: checkpoint_data["completed_prompts"] = all_completed_prompts self._save_checkpoint(checkpoint_data, lock=checkpoint_lock) except Exception as ckpt_err: print(f"⚠️ Warning: Failed to save final checkpoint: {ckpt_err}") - + # Calculate success rates for tool_name in total_tool_stats: stats = total_tool_stats[tool_name] @@ -985,53 +945,59 @@ class BatchRunner: else: stats["success_rate"] = 0.0 stats["failure_rate"] = 0.0 - + # Combine ALL batch files in directory into a single trajectories.jsonl file # This includes both old batches (from previous runs) and new batches (from resume) # Also filter out corrupted entries (where model generated invalid tool names) combined_file = self.output_dir / "trajectories.jsonl" print(f"\n📦 Combining ALL batch files into {combined_file.name}...") - + # Valid tools auto-derived from model_tools.py — no manual updates needed VALID_TOOLS = ALL_POSSIBLE_TOOLS - + total_entries = 0 filtered_entries = 0 batch_files_found = 0 - + # Find ALL batch files in the output directory (handles resume merging old + new) all_batch_files = sorted(self.output_dir.glob("batch_*.jsonl")) - - with open(combined_file, 'w', encoding='utf-8') as outfile: + + with open(combined_file, "w", encoding="utf-8") as outfile: for batch_file in all_batch_files: batch_files_found += 1 batch_num = batch_file.stem.split("_")[1] # Extract batch number for logging - - with open(batch_file, 'r', encoding='utf-8') as infile: + + with open(batch_file, encoding="utf-8") as infile: for line in infile: total_entries += 1 try: data = json.loads(line) - tool_stats = data.get('tool_stats', {}) - + tool_stats = data.get("tool_stats", {}) + # Check for invalid tool names (model hallucinations) invalid_tools = [k for k in tool_stats.keys() if k not in VALID_TOOLS] - + if invalid_tools: filtered_entries += 1 - invalid_preview = invalid_tools[0][:50] + "..." if len(invalid_tools[0]) > 50 else invalid_tools[0] - print(f" ⚠️ Filtering corrupted entry (batch {batch_num}): invalid tool '{invalid_preview}'") + invalid_preview = ( + invalid_tools[0][:50] + "..." if len(invalid_tools[0]) > 50 else invalid_tools[0] + ) + print( + f" ⚠️ Filtering corrupted entry (batch {batch_num}): invalid tool '{invalid_preview}'" + ) continue - + outfile.write(line) except json.JSONDecodeError: filtered_entries += 1 print(f" ⚠️ Filtering invalid JSON entry (batch {batch_num})") - + if filtered_entries > 0: print(f"⚠️ Filtered {filtered_entries} corrupted entries out of {total_entries} total") - print(f"✅ Combined {batch_files_found} batch files into trajectories.jsonl ({total_entries - filtered_entries} entries)") - + print( + f"✅ Combined {batch_files_found} batch files into trajectories.jsonl ({total_entries - filtered_entries} entries)" + ) + # Save final statistics final_stats = { "run_name": self.run_name, @@ -1045,10 +1011,10 @@ class BatchRunner: "tool_statistics": total_tool_stats, "reasoning_statistics": total_reasoning_stats, } - - with open(self.stats_file, 'w', encoding='utf-8') as f: + + with open(self.stats_file, "w", encoding="utf-8") as f: json.dump(final_stats, f, indent=2, ensure_ascii=False) - + # Print summary print("\n" + "=" * 70) print("📊 BATCH PROCESSING COMPLETE") @@ -1057,17 +1023,13 @@ class BatchRunner: print(f"✅ Total trajectories in merged file: {total_entries - filtered_entries}") print(f"✅ Total batch files merged: {batch_files_found}") print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s") - print(f"\n📈 Tool Usage Statistics:") + print("\n📈 Tool Usage Statistics:") print("-" * 70) - + if total_tool_stats: # Sort by count descending - sorted_tools = sorted( - total_tool_stats.items(), - key=lambda x: x[1]["count"], - reverse=True - ) - + sorted_tools = sorted(total_tool_stats.items(), key=lambda x: x[1]["count"], reverse=True) + print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}") print("-" * 70) for tool_name, stats in sorted_tools: @@ -1080,11 +1042,11 @@ class BatchRunner: ) else: print("No tool calls were made during this run.") - + # Print reasoning coverage stats total_discarded = sum(r.get("discarded_no_reasoning", 0) for r in results) - - print(f"\n🧠 Reasoning Coverage:") + + print("\n🧠 Reasoning Coverage:") print("-" * 70) total_turns = total_reasoning_stats["total_assistant_turns"] with_reasoning = total_reasoning_stats["turns_with_reasoning"] @@ -1099,10 +1061,10 @@ class BatchRunner: print(" No assistant turns recorded.") if total_discarded > 0: print(f" 🚫 Samples discarded (zero reasoning): {total_discarded:,}") - + print(f"\n💾 Results saved to: {self.output_dir}") - print(f" - Trajectories: trajectories.jsonl (combined)") - print(f" - Individual batches: batch_*.jsonl (for debugging)") + print(" - Trajectories: trajectories.jsonl (combined)") + print(" - Individual batches: batch_*.jsonl (for debugging)") print(f" - Statistics: {self.stats_file.name}") print(f" - Checkpoint: {self.checkpoint_file.name}") @@ -1159,62 +1121,63 @@ def main( reasoning_disabled (bool): Completely disable reasoning/thinking tokens (default: False) prefill_messages_file (str): Path to JSON file containing prefill messages (list of {role, content} dicts) max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set) - + Examples: # Basic usage python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run - + # Resume interrupted run python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume - + # Use specific distribution python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=image_test --distribution=image_gen - + # With disabled reasoning and max tokens python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\ --reasoning_disabled --max_tokens=128000 - + # With prefill messages from file python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\ --prefill_messages_file=configs/prefill_opus.json - + # List available distributions python batch_runner.py --list_distributions """ # Handle list distributions if list_distributions: - from toolset_distributions import list_distributions as get_all_dists, print_distribution_info - + from toolset_distributions import list_distributions as get_all_dists + from toolset_distributions import print_distribution_info + print("📊 Available Toolset Distributions") print("=" * 70) - + all_dists = get_all_dists() for dist_name in sorted(all_dists.keys()): print_distribution_info(dist_name) - + print("\n💡 Usage:") print(" python batch_runner.py --dataset_file=data.jsonl --batch_size=10 \\") print(" --run_name=my_run --distribution=") return - + # Validate required arguments if not dataset_file: print("❌ Error: --dataset_file is required") return - + if not batch_size or batch_size < 1: print("❌ Error: --batch_size must be a positive integer") return - + if not run_name: print("❌ Error: --run_name is required") return - + # Parse provider preferences (comma-separated strings to lists) providers_allowed_list = [p.strip() for p in providers_allowed.split(",")] if providers_allowed else None providers_ignored_list = [p.strip() for p in providers_ignored.split(",")] if providers_ignored else None providers_order_list = [p.strip() for p in providers_order.split(",")] if providers_order else None - + # Build reasoning_config from CLI flags # --reasoning_disabled takes priority, then --reasoning_effort, then default (medium) reasoning_config = None @@ -1230,21 +1193,21 @@ def main( return reasoning_config = {"enabled": True, "effort": reasoning_effort} print(f"🧠 Reasoning effort: {reasoning_effort}") - + # Load prefill messages from JSON file if provided prefill_messages = None if prefill_messages_file: try: - with open(prefill_messages_file, 'r', encoding='utf-8') as f: + with open(prefill_messages_file, encoding="utf-8") as f: prefill_messages = json.load(f) if not isinstance(prefill_messages, list): - print(f"❌ Error: prefill_messages_file must contain a JSON array of messages") + print("❌ Error: prefill_messages_file must contain a JSON array of messages") return print(f"💬 Loaded {len(prefill_messages)} prefill messages from {prefill_messages_file}") except Exception as e: print(f"❌ Error loading prefill messages: {e}") return - + # Initialize and run batch runner try: runner = BatchRunner( @@ -1271,7 +1234,7 @@ def main( ) runner.run(resume=resume) - + except Exception as e: print(f"\n❌ Fatal error: {e}") if verbose: @@ -1281,4 +1244,3 @@ def main( if __name__ == "__main__": fire.Fire(main) - diff --git a/cli.py b/cli.py index c82e85dc86..6e45d59ea5 100755 --- a/cli.py +++ b/cli.py @@ -12,16 +12,16 @@ Usage: python cli.py --list-tools # List available tools and exit """ +import atexit +import json import logging import os import shutil import sys -import json -import atexit import uuid -from pathlib import Path from datetime import datetime -from typing import List, Dict, Any, Optional +from pathlib import Path +from typing import Any logger = logging.getLogger(__name__) @@ -29,33 +29,34 @@ logger = logging.getLogger(__name__) os.environ["MSWEA_SILENT_STARTUP"] = "1" # mini-swe-agent os.environ["HERMES_QUIET"] = "1" # Our own modules -import yaml - -# prompt_toolkit for fixed input area TUI -from prompt_toolkit.history import FileHistory -from prompt_toolkit.styles import Style as PTStyle -from prompt_toolkit.patch_stdout import patch_stdout -from prompt_toolkit.application import Application -from prompt_toolkit.layout import Layout, HSplit, Window, FormattedTextControl, ConditionalContainer -from prompt_toolkit.layout.processors import Processor, Transformation, PasswordProcessor, ConditionalProcessor -from prompt_toolkit.filters import Condition -from prompt_toolkit.layout.dimension import Dimension -from prompt_toolkit.layout.menus import CompletionsMenu -from prompt_toolkit.widgets import TextArea -from prompt_toolkit.key_binding import KeyBindings -from prompt_toolkit import print_formatted_text as _pt_print -from prompt_toolkit.formatted_text import ANSI as _PT_ANSI -import threading import queue +import threading +import yaml # Load .env from ~/.hermes/.env first, then project root as dev fallback from dotenv import load_dotenv +from prompt_toolkit import print_formatted_text as _pt_print +from prompt_toolkit.application import Application +from prompt_toolkit.filters import Condition +from prompt_toolkit.formatted_text import ANSI as _PT_ANSI + +# prompt_toolkit for fixed input area TUI +from prompt_toolkit.history import FileHistory +from prompt_toolkit.key_binding import KeyBindings +from prompt_toolkit.layout import ConditionalContainer, FormattedTextControl, HSplit, Layout, Window +from prompt_toolkit.layout.dimension import Dimension +from prompt_toolkit.layout.menus import CompletionsMenu +from prompt_toolkit.layout.processors import ConditionalProcessor, PasswordProcessor, Processor, Transformation +from prompt_toolkit.patch_stdout import patch_stdout +from prompt_toolkit.styles import Style as PTStyle +from prompt_toolkit.widgets import TextArea + from hermes_constants import OPENROUTER_BASE_URL _hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) _user_env = _hermes_home / ".env" -_project_env = Path(__file__).parent / '.env' +_project_env = Path(__file__).parent / ".env" if _user_env.exists(): try: load_dotenv(dotenv_path=_user_env, encoding="utf-8") @@ -74,12 +75,13 @@ os.environ.setdefault("MSWEA_GLOBAL_CONFIG_DIR", str(_hermes_home)) # Configuration Loading # ============================================================================= -def _load_prefill_messages(file_path: str) -> List[Dict[str, Any]]: + +def _load_prefill_messages(file_path: str) -> list[dict[str, Any]]: """Load ephemeral prefill messages from a JSON file. - + The file should contain a JSON array of {role, content} dicts, e.g.: [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}] - + Relative paths are resolved from ~/.hermes/. Returns an empty list if the path is empty or the file doesn't exist. """ @@ -92,7 +94,7 @@ def _load_prefill_messages(file_path: str) -> List[Dict[str, Any]]: logger.warning("Prefill messages file not found: %s", path) return [] try: - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: data = json.load(f) if not isinstance(data, list): logger.warning("Prefill messages file must contain a JSON array: %s", path) @@ -105,7 +107,7 @@ def _load_prefill_messages(file_path: str) -> List[Dict[str, Any]]: def _parse_reasoning_config(effort: str) -> dict | None: """Parse a reasoning effort level into an OpenRouter reasoning config dict. - + Valid levels: "xhigh", "high", "medium", "low", "minimal", "none". Returns None to use the default (medium), or a config dict to override. """ @@ -121,27 +123,27 @@ def _parse_reasoning_config(effort: str) -> dict | None: return None -def load_cli_config() -> Dict[str, Any]: +def load_cli_config() -> dict[str, Any]: """ Load CLI configuration from config files. - + Config lookup order: 1. ~/.hermes/config.yaml (user config - preferred) 2. ./cli-config.yaml (project config - fallback) - + Environment variables take precedence over config file values. Returns default values if no config file exists. """ # Check user config first (~/.hermes/config.yaml) - user_config_path = Path.home() / '.hermes' / 'config.yaml' - project_config_path = Path(__file__).parent / 'cli-config.yaml' - + user_config_path = Path.home() / ".hermes" / "config.yaml" + project_config_path = Path(__file__).parent / "cli-config.yaml" + # Use user config if it exists, otherwise project config if user_config_path.exists(): config_path = user_config_path else: config_path = project_config_path - + # Default configuration defaults = { "model": { @@ -165,8 +167,8 @@ def load_cli_config() -> Dict[str, Any]: "record_sessions": False, # Auto-record browser sessions as WebM videos }, "compression": { - "enabled": True, # Auto-compress when approaching context limit - "threshold": 0.85, # Compress at 85% of model's context limit + "enabled": True, # Auto-compress when approaching context limit + "threshold": 0.85, # Compress at 85% of model's context limit "summary_model": "google/gemini-3-flash-preview", # Fast/cheap model for summaries }, "agent": { @@ -201,7 +203,7 @@ def load_cli_config() -> Dict[str, Any]: "timeout": 120, # Seconds to wait for a clarify answer before auto-proceeding }, "code_execution": { - "timeout": 300, # Max seconds a sandbox script can run before being killed (5 min) + "timeout": 300, # Max seconds a sandbox script can run before being killed (5 min) "max_tool_calls": 50, # Max RPC tool calls per execution }, "delegation": { @@ -209,7 +211,7 @@ def load_cli_config() -> Dict[str, Any]: "default_toolsets": ["terminal", "file", "web"], # Default toolsets for subagents }, } - + # Track whether the config file explicitly set terminal config. # When using defaults (no config file / no terminal section), we should NOT # overwrite env vars that were already set by .env -- only a user's config @@ -219,9 +221,9 @@ def load_cli_config() -> Dict[str, Any]: # Load from file if exists if config_path.exists(): try: - with open(config_path, "r") as f: + with open(config_path) as f: file_config = yaml.safe_load(f) or {} - + _file_has_terminal_config = "terminal" in file_config # Handle model config - can be string (new format) or dict (old format) @@ -232,7 +234,7 @@ def load_cli_config() -> Dict[str, Any]: elif isinstance(file_config["model"], dict): # Old format: model is a dict with default/base_url defaults["model"].update(file_config["model"]) - + # Deep merge file_config into defaults. # First: merge keys that exist in both (deep-merge dicts, overwrite scalars) for key in defaults: @@ -243,28 +245,28 @@ def load_cli_config() -> Dict[str, Any]: defaults[key].update(file_config[key]) else: defaults[key] = file_config[key] - + # Second: carry over keys from file_config that aren't in defaults # (e.g. platform_toolsets, provider_routing, memory, honcho, etc.) for key in file_config: if key not in defaults and key != "model": defaults[key] = file_config[key] - + # Handle root-level max_turns (backwards compat) - copy to agent.max_turns if "max_turns" in file_config and "agent" not in file_config: defaults["agent"]["max_turns"] = file_config["max_turns"] except Exception as e: logger.warning("Failed to load cli-config.yaml: %s", e) - + # Apply terminal config to environment variables (so terminal_tool picks them up) terminal_config = defaults.get("terminal", {}) - + # Normalize config key: the new config system (hermes_cli/config.py) and all # documentation use "backend", the legacy cli-config.yaml uses "env_type". # Accept both, with "backend" taking precedence (it's the documented key). if "backend" in terminal_config: terminal_config["env_type"] = terminal_config["backend"] - + # Handle special cwd values: "." or "auto" means use current working directory. # Only resolve to the host's CWD for the local backend where the host # filesystem is directly accessible. For ALL remote/container backends @@ -278,7 +280,7 @@ def load_cli_config() -> Dict[str, Any]: else: # Remove so TERMINAL_CWD stays unset → tool picks backend default terminal_config.pop("cwd", None) - + env_mappings = { "env_type": "TERMINAL_ENV", "cwd": "TERMINAL_CWD", @@ -303,7 +305,7 @@ def load_cli_config() -> Dict[str, Any]: # Sudo support (works with all backends) "sudo_password": "SUDO_PASSWORD", } - + # Apply config values to env vars so terminal_tool picks them up. # If the config file explicitly has a [terminal] section, those values are # authoritative and override any .env settings. When using defaults only @@ -315,20 +317,21 @@ def load_cli_config() -> Dict[str, Any]: val = terminal_config[config_key] if isinstance(val, list): import json + os.environ[env_var] = json.dumps(val) else: os.environ[env_var] = str(val) - + # Apply browser config to environment variables browser_config = defaults.get("browser", {}) browser_env_mappings = { "inactivity_timeout": "BROWSER_INACTIVITY_TIMEOUT", } - + for config_key, env_var in browser_env_mappings.items(): if config_key in browser_config: os.environ[env_var] = str(browser_config[config_key]) - + # Apply compression config to environment variables compression_config = defaults.get("compression", {}) compression_env_mappings = { @@ -337,11 +340,11 @@ def load_cli_config() -> Dict[str, Any]: "summary_model": "CONTEXT_COMPRESSION_MODEL", "summary_provider": "CONTEXT_COMPRESSION_PROVIDER", } - + for config_key, env_var in compression_env_mappings.items(): if config_key in compression_config: os.environ[env_var] = str(compression_config[config_key]) - + # Apply auxiliary model overrides to environment variables. # Vision and web_extract each have their own provider + model pair. # (Compression is handled in the compression section above.) @@ -350,10 +353,10 @@ def load_cli_config() -> Dict[str, Any]: auxiliary_config = defaults.get("auxiliary", {}) auxiliary_task_env = { # config key → (provider env var, model env var) - "vision": ("AUXILIARY_VISION_PROVIDER", "AUXILIARY_VISION_MODEL"), - "web_extract": ("AUXILIARY_WEB_EXTRACT_PROVIDER", "AUXILIARY_WEB_EXTRACT_MODEL"), + "vision": ("AUXILIARY_VISION_PROVIDER", "AUXILIARY_VISION_MODEL"), + "web_extract": ("AUXILIARY_WEB_EXTRACT_PROVIDER", "AUXILIARY_WEB_EXTRACT_MODEL"), } - + for task_key, (prov_env, model_env) in auxiliary_task_env.items(): task_cfg = auxiliary_config.get(task_key, {}) if not isinstance(task_cfg, dict): @@ -364,7 +367,7 @@ def load_cli_config() -> Dict[str, Any]: os.environ[prov_env] = prov if model: os.environ[model_env] = model - + # Security settings security_config = defaults.get("security", {}) if isinstance(security_config, dict): @@ -374,41 +377,52 @@ def load_cli_config() -> Dict[str, Any]: return defaults + # Load configuration at module startup CLI_CONFIG = load_cli_config() +import fire from rich.console import Console from rich.panel import Panel from rich.table import Table -import fire - -# Import the agent and tool systems -from run_agent import AIAgent -from model_tools import get_tool_definitions, get_toolset_for_tool +# Cron job system for scheduled tasks (CRUD only — execution is handled by the gateway) +from cron import create_job, get_job, list_jobs, remove_job +from hermes_cli.banner import ( + _BOLD, + _DIM, + _GOLD, + _RST, + COMPACT_BANNER, + HERMES_AGENT_LOGO, + HERMES_CADUCEUS, + VERSION, + build_welcome_banner, +) # Extracted CLI modules (Phase 3) from hermes_cli.banner import ( - cprint as _cprint, _GOLD, _BOLD, _DIM, _RST, - VERSION, HERMES_AGENT_LOGO, HERMES_CADUCEUS, COMPACT_BANNER, + cprint as _cprint, +) +from hermes_cli.banner import ( get_available_skills as _get_available_skills, - build_welcome_banner, ) from hermes_cli.commands import COMMANDS, SlashCommandCompleter -from hermes_cli import callbacks as _callbacks -from toolsets import get_all_toolsets, get_toolset_info, resolve_toolset, validate_toolset +from model_tools import get_tool_definitions, get_toolset_for_tool -# Cron job system for scheduled tasks (CRUD only — execution is handled by the gateway) -from cron import create_job, list_jobs, remove_job, get_job +# Import the agent and tool systems +from run_agent import AIAgent +from tools.browser_tool import _emergency_cleanup_all_sessions as _cleanup_all_browsers # Resource cleanup imports for safe shutdown (terminal VMs, browser sessions) from tools.terminal_tool import cleanup_all_environments as _cleanup_all_terminals -from tools.terminal_tool import set_sudo_password_callback, set_approval_callback -from tools.browser_tool import _emergency_cleanup_all_sessions as _cleanup_all_browsers +from tools.terminal_tool import set_approval_callback, set_sudo_password_callback +from toolsets import get_all_toolsets, get_toolset_info, validate_toolset # Guard to prevent cleanup from running multiple times on exit _cleanup_done = False + def _run_cleanup(): """Run resource cleanup exactly once.""" global _cleanup_done @@ -425,6 +439,7 @@ def _run_cleanup(): pass try: from tools.mcp_tool import shutdown_mcp_servers + shutdown_mcp_servers() except Exception: pass @@ -435,16 +450,19 @@ def _run_cleanup(): # ============================================================================= # Tracks the active worktree for cleanup on exit -_active_worktree: Optional[Dict[str, str]] = None +_active_worktree: dict[str, str] | None = None -def _git_repo_root() -> Optional[str]: +def _git_repo_root() -> str | None: """Return the git repo root for CWD, or None if not in a repo.""" import subprocess + try: result = subprocess.run( ["git", "rev-parse", "--show-toplevel"], - capture_output=True, text=True, timeout=5, + capture_output=True, + text=True, + timeout=5, ) if result.returncode == 0: return result.stdout.strip() @@ -453,7 +471,7 @@ def _git_repo_root() -> Optional[str]: return None -def _setup_worktree(repo_root: str = None) -> Optional[Dict[str, str]]: +def _setup_worktree(repo_root: str = None) -> dict[str, str] | None: """Create an isolated git worktree for this CLI session. Returns a dict with worktree metadata on success, None on failure. @@ -493,7 +511,10 @@ def _setup_worktree(repo_root: str = None) -> Optional[Dict[str, str]]: try: result = subprocess.run( ["git", "worktree", "add", str(wt_path), "-b", branch_name, "HEAD"], - capture_output=True, text=True, timeout=30, cwd=repo_root, + capture_output=True, + text=True, + timeout=30, + cwd=repo_root, ) if result.returncode != 0: print(f"\033[31m✗ Failed to create worktree: {result.stderr.strip()}\033[0m") @@ -535,7 +556,7 @@ def _setup_worktree(repo_root: str = None) -> Optional[Dict[str, str]]: return info -def _cleanup_worktree(info: Dict[str, str] = None) -> None: +def _cleanup_worktree(info: dict[str, str] = None) -> None: """Remove a worktree and its branch on exit. If the worktree has uncommitted changes, warn and keep it. @@ -558,7 +579,10 @@ def _cleanup_worktree(info: Dict[str, str] = None) -> None: try: status = subprocess.run( ["git", "status", "--porcelain"], - capture_output=True, text=True, timeout=10, cwd=wt_path, + capture_output=True, + text=True, + timeout=10, + cwd=wt_path, ) has_changes = bool(status.stdout.strip()) except Exception: @@ -574,7 +598,10 @@ def _cleanup_worktree(info: Dict[str, str] = None) -> None: try: subprocess.run( ["git", "worktree", "remove", wt_path, "--force"], - capture_output=True, text=True, timeout=15, cwd=repo_root, + capture_output=True, + text=True, + timeout=15, + cwd=repo_root, ) except Exception as e: logger.debug("Failed to remove worktree: %s", e) @@ -583,7 +610,10 @@ def _cleanup_worktree(info: Dict[str, str] = None) -> None: try: subprocess.run( ["git", "branch", "-D", branch], - capture_output=True, text=True, timeout=10, cwd=repo_root, + capture_output=True, + text=True, + timeout=10, + cwd=repo_root, ) except Exception as e: logger.debug("Failed to delete branch %s: %s", branch, e) @@ -623,7 +653,10 @@ def _prune_stale_worktrees(repo_root: str, max_age_hours: int = 24) -> None: try: status = subprocess.run( ["git", "status", "--porcelain"], - capture_output=True, text=True, timeout=5, cwd=str(entry), + capture_output=True, + text=True, + timeout=5, + cwd=str(entry), ) if status.stdout.strip(): continue # Has changes — skip @@ -634,23 +667,33 @@ def _prune_stale_worktrees(repo_root: str, max_age_hours: int = 24) -> None: try: branch_result = subprocess.run( ["git", "branch", "--show-current"], - capture_output=True, text=True, timeout=5, cwd=str(entry), + capture_output=True, + text=True, + timeout=5, + cwd=str(entry), ) branch = branch_result.stdout.strip() subprocess.run( ["git", "worktree", "remove", str(entry), "--force"], - capture_output=True, text=True, timeout=15, cwd=repo_root, + capture_output=True, + text=True, + timeout=15, + cwd=repo_root, ) if branch: subprocess.run( ["git", "branch", "-D", branch], - capture_output=True, text=True, timeout=10, cwd=repo_root, + capture_output=True, + text=True, + timeout=10, + cwd=repo_root, ) logger.debug("Pruned stale worktree: %s", entry.name) except Exception as e: logger.debug("Failed to prune worktree %s: %s", entry.name, e) + # ============================================================================ # ASCII Art & Branding # ============================================================================ @@ -663,11 +706,12 @@ def _prune_stale_worktrees(repo_root: str, max_age_hours: int = 24) -> None: # - Dim: #B8860B (muted text) # ANSI building blocks for conversation display -_GOLD = "\033[1;33m" # Bold yellow — closest universal match to the gold theme +_GOLD = "\033[1;33m" # Bold yellow — closest universal match to the gold theme _BOLD = "\033[1m" _DIM = "\033[2m" _RST = "\033[0m" + def _cprint(text: str): """Print ANSI-colored text through prompt_toolkit's native renderer. @@ -689,6 +733,7 @@ class ChatConsole: def __init__(self): from io import StringIO + self._buffer = StringIO() self._inner = Console(file=self._buffer, force_terminal=True, highlight=False) @@ -700,6 +745,7 @@ class ChatConsole: for line in output.rstrip("\n").split("\n"): _cprint(line) + # ASCII Art - HERMES-AGENT logo (full width, single line - requires ~95 char terminal) HERMES_AGENT_LOGO = """[bold #FFD700]██╗ ██╗███████╗██████╗ ███╗ ███╗███████╗███████╗ █████╗ ██████╗ ███████╗███╗ ██╗████████╗[/] [bold #FFD700]██║ ██║██╔════╝██╔══██╗████╗ ████║██╔════╝██╔════╝ ██╔══██╗██╔════╝ ██╔════╝████╗ ██║╚══██╔══╝[/] @@ -745,8 +791,8 @@ def _build_compact_banner() -> str: line1 = "⚕ NOUS HERMES - AI Agent Framework" line2 = "Messenger of the Digital Gods · Nous Research" # Truncate and pad to fit - line1 = line1[:inner - 2].ljust(inner - 2) - line2 = line2[:inner - 2].ljust(inner - 2) + line1 = line1[: inner - 2].ljust(inner - 2) + line2 = line2[: inner - 2].ljust(inner - 2) return ( f"\n[bold #FFD700]╔{bar}╗[/]\n" f"[bold #FFD700]║[/] [#FFBF00]{line1}[/] [bold #FFD700]║[/]\n" @@ -755,35 +801,35 @@ def _build_compact_banner() -> str: ) -def _get_available_skills() -> Dict[str, List[str]]: +def _get_available_skills() -> dict[str, list[str]]: """ Scan ~/.hermes/skills/ and return skills grouped by category. - + Returns: Dict mapping category name to list of skill names """ import os - + hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) skills_dir = hermes_home / "skills" skills_by_category = {} - + if not skills_dir.exists(): return skills_by_category - + for skill_file in skills_dir.rglob("SKILL.md"): rel_path = skill_file.relative_to(skills_dir) parts = rel_path.parts - + if len(parts) >= 2: category = parts[0] skill_name = parts[-2] else: category = "general" skill_name = skill_file.parent.name - + skills_by_category.setdefault(category, []).append(skill_name) - + return skills_by_category @@ -798,10 +844,18 @@ def _format_context_length(tokens: int) -> str: return str(tokens) -def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dict] = None, enabled_toolsets: List[str] = None, session_id: str = None, context_length: int = None): +def build_welcome_banner( + console: Console, + model: str, + cwd: str, + tools: list[dict] = None, + enabled_toolsets: list[str] = None, + session_id: str = None, + context_length: int = None, +): """ Build and print a Claude Code-style welcome banner with caduceus on left and info on right. - + Args: console: Rich Console instance for printing model: The current model name (e.g., "anthropic/claude-opus-4") @@ -811,46 +865,48 @@ def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dic session_id: Unique session identifier for logging context_length: Model's context window size in tokens """ - from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS - + from model_tools import check_tool_availability + tools = tools or [] enabled_toolsets = enabled_toolsets or [] - + # Get unavailable tools info for coloring _, unavailable_toolsets = check_tool_availability(quiet=True) disabled_tools = set() for item in unavailable_toolsets: disabled_tools.update(item.get("tools", [])) - + # Build the side-by-side content using a table for precise control layout_table = Table.grid(padding=(0, 2)) layout_table.add_column("left", justify="center") layout_table.add_column("right", justify="left") - + # Build left content: caduceus + model info left_lines = ["", HERMES_CADUCEUS, ""] - + # Shorten model name for display model_short = model.split("/")[-1] if "/" in model else model if len(model_short) > 28: model_short = model_short[:25] + "..." - - ctx_str = f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else "" + + ctx_str = ( + f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else "" + ) left_lines.append(f"[#FFBF00]{model_short}[/]{ctx_str} [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]") left_lines.append(f"[dim #B8860B]{cwd}[/]") - + # Add session ID if provided if session_id: left_lines.append(f"[dim #8B8682]Session: {session_id}[/]") left_content = "\n".join(left_lines) - + # Build right content: tools list grouped by toolset right_lines = [] right_lines.append("[bold #FFBF00]Available Tools[/]") - + # Group tools by toolset (include all possible tools, both enabled and disabled) toolsets_dict = {} - + # First, add all enabled tools for tool in tools: tool_name = tool["function"]["name"] @@ -858,7 +914,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dic if toolset not in toolsets_dict: toolsets_dict[toolset] = [] toolsets_dict[toolset].append(tool_name) - + # Also add disabled toolsets so they show in the banner for item in unavailable_toolsets: # Map the internal toolset ID to display name @@ -869,12 +925,12 @@ def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dic for tool_name in item.get("tools", []): if tool_name not in toolsets_dict[display_name]: toolsets_dict[display_name].append(tool_name) - + # Display tools grouped by toolset (compact format, max 8 groups) sorted_toolsets = sorted(toolsets_dict.keys()) display_toolsets = sorted_toolsets[:8] remaining_toolsets = len(sorted_toolsets) - 8 - + for toolset in display_toolsets: tool_names = toolsets_dict[toolset] # Color each tool name - red if disabled, normal if enabled @@ -884,7 +940,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dic colored_names.append(f"[red]{name}[/]") else: colored_names.append(f"[#FFF8DC]{name}[/]") - + tools_str = ", ".join(colored_names) # Truncate if too long (accounting for markup) if len(", ".join(sorted(tool_names))) > 45: @@ -907,19 +963,19 @@ def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dic else: colored_names.append(f"[#FFF8DC]{name}[/]") tools_str = ", ".join(colored_names) - + right_lines.append(f"[dim #B8860B]{toolset}:[/] {tools_str}") - + if remaining_toolsets > 0: right_lines.append(f"[dim #B8860B](and {remaining_toolsets} more toolsets...)[/]") - + right_lines.append("") - + # Add skills section right_lines.append("[bold #FFBF00]Available Skills[/]") skills_by_category = _get_available_skills() total_skills = sum(len(s) for s in skills_by_category.values()) - + if skills_by_category: for category in sorted(skills_by_category.keys()): skill_names = sorted(skills_by_category[category]) @@ -935,15 +991,15 @@ def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dic right_lines.append(f"[dim #B8860B]{category}:[/] [#FFF8DC]{skills_str}[/]") else: right_lines.append("[dim #B8860B]No skills installed[/]") - + right_lines.append("") right_lines.append(f"[dim #B8860B]{len(tools)} tools · {total_skills} skills · /help for commands[/]") - + right_content = "\n".join(right_lines) - + # Add to table layout_table.add_row(left_content, right_content) - + # Wrap in a panel with the title outer_panel = Panel( layout_table, @@ -951,14 +1007,14 @@ def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dic border_style="#CD7F32", padding=(0, 2), ) - + # Print the big HERMES-AGENT logo — skip if terminal is too narrow console.print() term_width = shutil.get_terminal_size().columns if term_width >= 95: console.print(HERMES_AGENT_LOGO) console.print() - + # Print the panel with caduceus and info console.print(outer_panel) @@ -967,7 +1023,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dic # Skill Slash Commands — dynamic commands generated from installed skills # ============================================================================ -from agent.skill_commands import scan_skill_commands, get_skill_commands, build_skill_invocation_message +from agent.skill_commands import build_skill_invocation_message, scan_skill_commands _skill_commands = scan_skill_commands() @@ -975,47 +1031,47 @@ _skill_commands = scan_skill_commands() def save_config_value(key_path: str, value: any) -> bool: """ Save a value to the active config file at the specified key path. - + Respects the same lookup order as load_cli_config(): 1. ~/.hermes/config.yaml (user config - preferred, used if it exists) 2. ./cli-config.yaml (project config - fallback) - + Args: key_path: Dot-separated path like "agent.system_prompt" value: Value to save - + Returns: True if successful, False otherwise """ # Use the same precedence as load_cli_config: user config first, then project config - user_config_path = Path.home() / '.hermes' / 'config.yaml' - project_config_path = Path(__file__).parent / 'cli-config.yaml' + user_config_path = Path.home() / ".hermes" / "config.yaml" + project_config_path = Path(__file__).parent / "cli-config.yaml" config_path = user_config_path if user_config_path.exists() else project_config_path - + try: # Ensure parent directory exists (for ~/.hermes/config.yaml on first use) config_path.parent.mkdir(parents=True, exist_ok=True) - + # Load existing config if config_path.exists(): - with open(config_path, 'r') as f: + with open(config_path) as f: config = yaml.safe_load(f) or {} else: config = {} - + # Navigate to the key and set value - keys = key_path.split('.') + keys = key_path.split(".") current = config for key in keys[:-1]: if key not in current or not isinstance(current[key], dict): current[key] = {} current = current[key] current[keys[-1]] = value - + # Save back - with open(config_path, 'w') as f: + with open(config_path, "w") as f: yaml.dump(config, f, default_flow_style=False, sort_keys=False) - + return True except Exception as e: logger.error("Failed to save config: %s", e) @@ -1026,18 +1082,19 @@ def save_config_value(key_path: str, value: any) -> bool: # HermesCLI Class # ============================================================================ + class HermesCLI: """ Interactive CLI for the Hermes Agent. - + Provides a REPL interface with rich formatting, command history, and tool execution capabilities. """ - + def __init__( self, model: str = None, - toolsets: List[str] = None, + toolsets: list[str] = None, provider: str = None, api_key: str = None, base_url: str = None, @@ -1070,7 +1127,7 @@ class HermesCLI: # bell_on_complete: play terminal bell (\a) when agent finishes a response self.bell_on_complete = CLI_CONFIG["display"].get("bell_on_complete", False) self.verbose = verbose if verbose is not None else (self.tool_progress_mode == "verbose") - + # Configuration - priority: CLI args > env vars > config file # Model can come from: CLI arg, LLM_MODEL env, OPENAI_MODEL env (custom endpoint), or config self.model = model or os.getenv("LLM_MODEL") or os.getenv("OPENAI_MODEL") or CLI_CONFIG["model"]["default"] @@ -1084,12 +1141,9 @@ class HermesCLI: # Provider selection is resolved lazily at use-time via _ensure_runtime_credentials(). self.requested_provider = ( - provider - or os.getenv("HERMES_INFERENCE_PROVIDER") - or CLI_CONFIG["model"].get("provider") - or "auto" + provider or os.getenv("HERMES_INFERENCE_PROVIDER") or CLI_CONFIG["model"].get("provider") or "auto" ) - self._provider_source: Optional[str] = None + self._provider_source: str | None = None self.provider = self.requested_provider self.api_mode = "chat_completions" self.base_url = ( @@ -1104,8 +1158,8 @@ class HermesCLI: self.api_key = api_key or os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") else: self.api_key = api_key or os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY") - self._nous_key_expires_at: Optional[str] = None - self._nous_key_source: Optional[str] = None + self._nous_key_expires_at: str | None = None + self._nous_key_source: str | None = None # Max turns priority: CLI arg > config file > env var > default if max_turns is not None: # CLI arg was explicitly set self.max_turns = max_turns @@ -1117,7 +1171,7 @@ class HermesCLI: self.max_turns = int(os.getenv("HERMES_MAX_ITERATIONS")) else: self.max_turns = 90 - + # Parse and validate toolsets self.enabled_toolsets = toolsets if toolsets and "all" not in toolsets and "*" not in toolsets: @@ -1125,24 +1179,19 @@ class HermesCLI: invalid = [t for t in toolsets if not validate_toolset(t)] if invalid: self.console.print(f"[bold red]Warning: Unknown toolsets: {', '.join(invalid)}[/]") - + # Ephemeral system prompt: env var takes precedence, then config - self.system_prompt = ( - os.getenv("HERMES_EPHEMERAL_SYSTEM_PROMPT", "") - or CLI_CONFIG["agent"].get("system_prompt", "") + self.system_prompt = os.getenv("HERMES_EPHEMERAL_SYSTEM_PROMPT", "") or CLI_CONFIG["agent"].get( + "system_prompt", "" ) self.personalities = CLI_CONFIG["agent"].get("personalities", {}) - + # Ephemeral prefill messages (few-shot priming, never persisted) - self.prefill_messages = _load_prefill_messages( - CLI_CONFIG["agent"].get("prefill_messages_file", "") - ) - + self.prefill_messages = _load_prefill_messages(CLI_CONFIG["agent"].get("prefill_messages_file", "")) + # Reasoning config (OpenRouter reasoning effort level) - self.reasoning_config = _parse_reasoning_config( - CLI_CONFIG["agent"].get("reasoning_effort", "") - ) - + self.reasoning_config = _parse_reasoning_config(CLI_CONFIG["agent"].get("reasoning_effort", "")) + # OpenRouter provider routing preferences pr = CLI_CONFIG.get("provider_routing", {}) or {} self._provider_sort = pr.get("sort") @@ -1151,30 +1200,31 @@ class HermesCLI: self._providers_order = pr.get("order") self._provider_require_params = pr.get("require_parameters", False) self._provider_data_collection = pr.get("data_collection") - + # Fallback model config — tried when primary provider fails after retries fb = CLI_CONFIG.get("fallback_model") or {} self._fallback_model = fb if fb.get("provider") and fb.get("model") else None # Agent will be initialized on first use - self.agent: Optional[AIAgent] = None + self.agent: AIAgent | None = None self._app = None # prompt_toolkit Application (set in run()) - + # Conversation state - self.conversation_history: List[Dict[str, Any]] = [] + self.conversation_history: list[dict[str, Any]] = [] self.session_start = datetime.now() self._resumed = False # Initialize SQLite session store early so /title works before first message self._session_db = None try: from hermes_state import SessionDB + self._session_db = SessionDB() except Exception: pass - + # Deferred title: stored in memory until the session is created in the DB - self._pending_title: Optional[str] = None - + self._pending_title: str | None = None + # Session ID: reuse existing one when resuming, otherwise generate fresh if resume: self.session_id = resume @@ -1183,7 +1233,7 @@ class HermesCLI: timestamp_str = self.session_start.strftime("%Y%m%d_%H%M%S") short_uuid = uuid.uuid4().hex[:6] self.session_id = f"{timestamp_str}_{short_uuid}" - + # History file for persistent input recall across sessions self._history_file = Path.home() / ".hermes_history" self._last_invalidate: float = 0.0 # throttle UI repaints @@ -1191,6 +1241,7 @@ class HermesCLI: def _invalidate(self, min_interval: float = 0.25) -> None: """Throttled UI repaint — prevents terminal blinking on slow/SSH connections.""" import time as _time + now = _time.monotonic() if hasattr(self, "_app") and self._app and (now - self._last_invalidate) >= min_interval: self._last_invalidate = now @@ -1223,8 +1274,7 @@ class HermesCLI: slug = current_model.split("/", 1)[1] if not self._model_is_default: self.console.print( - f"[yellow]⚠️ Stripped provider prefix from '{current_model}'; " - f"using '{slug}' for OpenAI Codex.[/]" + f"[yellow]⚠️ Stripped provider prefix from '{current_model}'; using '{slug}' for OpenAI Codex.[/]" ) self.model = slug current_model = slug @@ -1258,8 +1308,8 @@ class HermesCLI: Returns True if credentials are ready, False on auth failure. """ from hermes_cli.runtime_provider import ( - resolve_runtime_provider, format_runtime_provider_error, + resolve_runtime_provider, ) try: @@ -1285,10 +1335,7 @@ class HermesCLI: return False credentials_changed = api_key != self.api_key or base_url != self.base_url - routing_changed = ( - resolved_provider != self.provider - or resolved_api_mode != self.api_mode - ) + routing_changed = resolved_provider != self.provider or resolved_api_mode != self.api_mode self.provider = resolved_provider self.api_mode = resolved_api_mode self._provider_source = runtime.get("source") @@ -1310,7 +1357,7 @@ class HermesCLI: """ Initialize the agent on first use. When resuming a session, restores conversation history from SQLite. - + Returns: bool: True if successful, False otherwise """ @@ -1324,10 +1371,11 @@ class HermesCLI: if self._session_db is None: try: from hermes_state import SessionDB + self._session_db = SessionDB() except Exception as e: logger.debug("SQLite session store not available: %s", e) - + # If resuming, validate the session exists and load its history. # _preload_resumed_session() may have already loaded it (called from # run() for immediate display). In that case, conversation_history @@ -1344,7 +1392,7 @@ class HermesCLI: msg_count = len([m for m in restored if m.get("role") == "user"]) title_part = "" if session_meta.get("title"): - title_part = f" \"{session_meta['title']}\"" + title_part = f' "{session_meta["title"]}"' _cprint( f"{_GOLD}↻ Resumed session {_BOLD}{self.session_id}{_RST}{_GOLD}{title_part} " f"({msg_count} user message{'s' if msg_count != 1 else ''}, " @@ -1361,7 +1409,7 @@ class HermesCLI: self._session_db._conn.commit() except Exception: pass - + try: self.agent = AIAgent( model=self.model, @@ -1402,31 +1450,31 @@ class HermesCLI: except Exception as e: self.console.print(f"[bold red]Failed to initialize agent: {e}[/]") return False - + def show_banner(self): """Display the welcome banner in Claude Code style.""" self.console.clear() - + # Auto-compact for narrow terminals — the full banner with caduceus # + tool list needs ~80 columns minimum to render without wrapping. term_width = shutil.get_terminal_size().columns use_compact = self.compact or term_width < 80 - + if use_compact: self.console.print(_build_compact_banner()) self._show_status() else: # Get tools for display tools = get_tool_definitions(enabled_toolsets=self.enabled_toolsets, quiet_mode=True) - + # Get terminal working directory (where commands will execute) cwd = os.getenv("TERMINAL_CWD", os.getcwd()) - + # Get context length for display ctx_len = None - if hasattr(self, 'agent') and self.agent and hasattr(self.agent, 'context_compressor'): + if hasattr(self, "agent") and self.agent and hasattr(self.agent, "context_compressor"): ctx_len = self.agent.context_compressor.context_length - + # Build and display the banner build_welcome_banner( console=self.console, @@ -1437,10 +1485,10 @@ class HermesCLI: session_id=self.session_id, context_length=ctx_len, ) - + # Show tool availability warnings if any tools are disabled self._show_tool_availability_warnings() - + self.console.print() def _preload_resumed_session(self) -> bool: @@ -1459,13 +1507,8 @@ class HermesCLI: session_meta = self._session_db.get_session(self.session_id) if not session_meta: - self.console.print( - f"[bold red]Session not found: {self.session_id}[/]" - ) - self.console.print( - "[dim]Use a session ID from a previous CLI run " - "(hermes sessions list).[/]" - ) + self.console.print(f"[bold red]Session not found: {self.session_id}[/]") + self.console.print("[dim]Use a session ID from a previous CLI run (hermes sessions list).[/]") return False restored = self._session_db.get_messages_as_conversation(self.session_id) @@ -1482,17 +1525,13 @@ class HermesCLI: f"{len(restored)} total messages)[/]" ) else: - self.console.print( - f"[#DAA520]Session {self.session_id} found but has no " - f"messages. Starting fresh.[/]" - ) + self.console.print(f"[#DAA520]Session {self.session_id} found but has no messages. Starting fresh.[/]") return False # Re-open the session (clear ended_at so it's active again) try: self._session_db._conn.execute( - "UPDATE sessions SET ended_at = NULL, end_reason = NULL " - "WHERE id = ?", + "UPDATE sessions SET ended_at = NULL, end_reason = NULL WHERE id = ?", (self.session_id,), ) self._session_db._conn.commit() @@ -1516,23 +1555,28 @@ class HermesCLI: if self.resume_display == "minimal": return - MAX_DISPLAY_EXCHANGES = 10 # max user+assistant pairs to show - MAX_USER_LEN = 300 # truncate user messages - MAX_ASST_LEN = 200 # truncate assistant text - MAX_ASST_LINES = 3 # max lines of assistant text + MAX_DISPLAY_EXCHANGES = 10 # max user+assistant pairs to show + MAX_USER_LEN = 300 # truncate user messages + MAX_ASST_LEN = 200 # truncate assistant text + MAX_ASST_LINES = 3 # max lines of assistant text def _strip_reasoning(text: str) -> str: """Remove ... blocks from displayed text (reasoning model internal thoughts).""" import re + cleaned = re.sub( r".*?\s*", - "", text, flags=re.DOTALL, + "", + text, + flags=re.DOTALL, ) # Also strip unclosed reasoning tags at the end cleaned = re.sub( r".*$", - "", cleaned, flags=re.DOTALL, + "", + cleaned, + flags=re.DOTALL, ) return cleaned.strip() @@ -1665,6 +1709,7 @@ class HermesCLI: Windows Terminal with WSL2). """ from hermes_cli.clipboard import has_clipboard_image + if has_clipboard_image(): if self._try_attach_clipboard_image(): n = len(self._attached_images) @@ -1688,6 +1733,7 @@ class HermesCLI: """ import asyncio as _asyncio import json as _json + from tools.vision_tools import vision_analyze_tool analysis_prompt = ( @@ -1703,9 +1749,7 @@ class HermesCLI: size_kb = img_path.stat().st_size // 1024 _cprint(f" {_DIM}👁️ analyzing {img_path.name} ({size_kb}KB)...{_RST}") try: - result_json = _asyncio.run( - vision_analyze_tool(image_url=str(img_path), user_prompt=analysis_prompt) - ) + result_json = _asyncio.run(vision_analyze_tool(image_url=str(img_path), user_prompt=analysis_prompt)) result = _json.loads(result_json) if result.get("success"): description = result.get("analysis", "") @@ -1740,42 +1784,44 @@ class HermesCLI: def _show_tool_availability_warnings(self): """Show warnings about disabled tools due to missing API keys.""" try: - from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS - + from model_tools import check_tool_availability + available, unavailable = check_tool_availability() - + # Filter to only those missing API keys (not system deps) api_key_missing = [u for u in unavailable if u["missing_vars"]] - + if api_key_missing: self.console.print() self.console.print("[yellow]⚠️ Some tools disabled (missing API keys):[/]") for item in api_key_missing: tools_str = ", ".join(item["tools"][:2]) # Show first 2 tools if len(item["tools"]) > 2: - tools_str += f", +{len(item['tools'])-2} more" - self.console.print(f" [dim]• {item['name']}[/] [dim italic]({', '.join(item['missing_vars'])})[/]") + tools_str += f", +{len(item['tools']) - 2} more" + self.console.print( + f" [dim]• {item['name']}[/] [dim italic]({', '.join(item['missing_vars'])})[/]" + ) self.console.print("[dim] Run 'hermes setup' to configure[/]") except Exception: pass # Don't crash on import errors - + def _show_status(self): """Show current status bar.""" # Get tool count tools = get_tool_definitions(enabled_toolsets=self.enabled_toolsets, quiet_mode=True) tool_count = len(tools) if tools else 0 - + # Format model name (shorten if needed) model_short = self.model.split("/")[-1] if "/" in self.model else self.model if len(model_short) > 30: model_short = model_short[:27] + "..." - + # Get API status indicator if self.api_key: api_indicator = "[green bold]●[/]" else: api_indicator = "[red bold]●[/]" - + # Build status line with proper markup toolsets_info = "" if self.enabled_toolsets and "all" not in self.enabled_toolsets: @@ -1790,16 +1836,16 @@ class HermesCLI: f"[dim #B8860B]·[/] [bold cyan]{tool_count} tools[/]" f"{toolsets_info}{provider_info}" ) - + def show_help(self): """Display help information.""" _cprint(f"\n{_BOLD}+{'-' * 50}+{_RST}") _cprint(f"{_BOLD}|{' ' * 14}(^_^)? Available Commands{' ' * 10}|{_RST}") _cprint(f"{_BOLD}+{'-' * 50}+{_RST}\n") - + for cmd, desc in COMMANDS.items(): _cprint(f" {_GOLD}{cmd:<15}{_RST} {_DIM}-{_RST} {desc}") - + if _skill_commands: _cprint(f"\n ⚡ {_BOLD}Skill Commands{_RST} ({len(_skill_commands)} installed):") for cmd, info in sorted(_skill_commands.items()): @@ -1808,15 +1854,15 @@ class HermesCLI: _cprint(f"\n {_DIM}Tip: Just type your message to chat with Hermes!{_RST}") _cprint(f" {_DIM}Multi-line: Alt+Enter for a new line{_RST}") _cprint(f" {_DIM}Paste image: Alt+V (or /paste){_RST}\n") - + def show_tools(self): """Display available tools with kawaii ASCII art.""" tools = get_tool_definitions(enabled_toolsets=self.enabled_toolsets, quiet_mode=True) - + if not tools: print("(;_;) No tools available") return - + # Header print() title = "(^_^)/ Available Tools" @@ -1826,7 +1872,7 @@ class HermesCLI: print("|" + " " * (pad // 2) + title + " " * (pad - pad // 2) + "|") print("+" + "-" * width + "+") print() - + # Group tools by toolset toolsets = {} for tool in sorted(tools, key=lambda t: t["function"]["name"]): @@ -1838,23 +1884,23 @@ class HermesCLI: # First sentence: split on ". " (period+space) to avoid breaking on "e.g." or "v2.0" desc = desc.split("\n")[0] if ". " in desc: - desc = desc[:desc.index(". ") + 1] + desc = desc[: desc.index(". ") + 1] toolsets[toolset].append((name, desc)) - + # Display by toolset for toolset in sorted(toolsets.keys()): print(f" [{toolset}]") for name, desc in toolsets[toolset]: print(f" * {name:<20} - {desc}") print() - + print(f" Total: {len(tools)} tools ヽ(^o^)ノ") print() - + def show_toolsets(self): """Display available toolsets with kawaii ASCII art.""" all_toolsets = get_all_toolsets() - + # Header print() title = "(^_^)b Available Toolsets" @@ -1864,41 +1910,41 @@ class HermesCLI: print("|" + " " * (pad // 2) + title + " " * (pad - pad // 2) + "|") print("+" + "-" * width + "+") print() - + for name in sorted(all_toolsets.keys()): info = get_toolset_info(name) if info: tool_count = info["tool_count"] desc = info["description"] - + # Mark if currently enabled marker = "(*)" if self.enabled_toolsets and name in self.enabled_toolsets else " " print(f" {marker} {name:<18} [{tool_count:>2} tools] - {desc}") - + print() print(" (*) = currently enabled") print() print(" Tip: Use 'all' or '*' to enable all toolsets") print(" Example: python cli.py --toolsets web,terminal") print() - + def show_config(self): """Display current configuration with kawaii ASCII art.""" # Get terminal config from environment (which was set from cli-config.yaml) terminal_env = os.getenv("TERMINAL_ENV", "local") terminal_cwd = os.getenv("TERMINAL_CWD", os.getcwd()) terminal_timeout = os.getenv("TERMINAL_TIMEOUT", "60") - - user_config_path = Path.home() / '.hermes' / 'config.yaml' - project_config_path = Path(__file__).parent / 'cli-config.yaml' + + user_config_path = Path.home() / ".hermes" / "config.yaml" + project_config_path = Path(__file__).parent / "cli-config.yaml" if user_config_path.exists(): config_path = user_config_path else: config_path = project_config_path config_status = "(loaded)" if config_path.exists() else "(not found)" - - api_key_display = '********' + self.api_key[-4:] if self.api_key and len(self.api_key) > 4 else 'Not set!' - + + api_key_display = "********" + self.api_key[-4:] if self.api_key and len(self.api_key) > 4 else "Not set!" + print() title = "(^_^) Configuration" width = 50 @@ -1931,7 +1977,7 @@ class HermesCLI: print(f" Started: {self.session_start.strftime('%Y-%m-%d %H:%M:%S')}") print(f" Config File: {config_path} {config_status}") print() - + def show_history(self): """Display conversation history.""" if not self.conversation_history: @@ -1975,9 +2021,7 @@ class HermesCLI: if role == "user": print(f"\n [You #{visible_index}]") - print( - f" {content_text[:preview_limit]}{'...' if len(content_text) > preview_limit else ''}" - ) + print(f" {content_text[:preview_limit]}{'...' if len(content_text) > preview_limit else ''}") continue print(f"\n [Hermes #{visible_index}]") @@ -1997,7 +2041,7 @@ class HermesCLI: flush_tool_summary() print() - + def reset_conversation(self): """Reset the conversation history.""" if self.agent and self.conversation_history: @@ -2007,30 +2051,35 @@ class HermesCLI: pass self.conversation_history = [] print("(^_^)b Conversation reset!") - + def save_conversation(self): """Save the current conversation to a file.""" if not self.conversation_history: print("(;_;) No conversation to save.") return - + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"hermes_conversation_{timestamp}.json" - + try: with open(filename, "w", encoding="utf-8") as f: - json.dump({ - "model": self.model, - "session_start": self.session_start.isoformat(), - "messages": self.conversation_history, - }, f, indent=2, ensure_ascii=False) + json.dump( + { + "model": self.model, + "session_start": self.session_start.isoformat(), + "messages": self.conversation_history, + }, + f, + indent=2, + ensure_ascii=False, + ) print(f"(^_^)v Conversation saved to: {filename}") except Exception as e: print(f"(x_x) Failed to save: {e}") - + def retry_last(self): """Retry the last user message by removing the last exchange and re-sending. - + Removes the last assistant response (and any tool-call messages) and the last user message, then re-sends that user message to the agent. Returns the message to re-send, or None if there's nothing to retry. @@ -2038,65 +2087,67 @@ class HermesCLI: if not self.conversation_history: print("(._.) No messages to retry.") return None - + # Walk backwards to find the last user message last_user_idx = None for i in range(len(self.conversation_history) - 1, -1, -1): if self.conversation_history[i].get("role") == "user": last_user_idx = i break - + if last_user_idx is None: print("(._.) No user message found to retry.") return None - + # Extract the message text and remove everything from that point forward last_message = self.conversation_history[last_user_idx].get("content", "") self.conversation_history = self.conversation_history[:last_user_idx] - - print(f"(^_^)b Retrying: \"{last_message[:60]}{'...' if len(last_message) > 60 else ''}\"") + + print(f'(^_^)b Retrying: "{last_message[:60]}{"..." if len(last_message) > 60 else ""}"') return last_message - + def undo_last(self): """Remove the last user/assistant exchange from conversation history. - + Walks backwards and removes all messages from the last user message onward (including assistant responses, tool calls, etc.). """ if not self.conversation_history: print("(._.) No messages to undo.") return - + # Walk backwards to find the last user message last_user_idx = None for i in range(len(self.conversation_history) - 1, -1, -1): if self.conversation_history[i].get("role") == "user": last_user_idx = i break - + if last_user_idx is None: print("(._.) No user message found to undo.") return - + # Count how many messages we're removing removed_count = len(self.conversation_history) - last_user_idx removed_msg = self.conversation_history[last_user_idx].get("content", "") - + # Truncate history to before the last user message self.conversation_history = self.conversation_history[:last_user_idx] - - print(f"(^_^)b Undid {removed_count} message(s). Removed: \"{removed_msg[:60]}{'...' if len(removed_msg) > 60 else ''}\"") + + print( + f'(^_^)b Undid {removed_count} message(s). Removed: "{removed_msg[:60]}{"..." if len(removed_msg) > 60 else ""}"' + ) remaining = len(self.conversation_history) print(f" {remaining} message(s) remaining in history.") - + def _handle_prompt_command(self, cmd: str): """Handle the /prompt command to view or set system prompt.""" parts = cmd.split(maxsplit=1) - + if len(parts) > 1: # Set new prompt new_prompt = parts[1].strip() - + if new_prompt.lower() == "clear": self.system_prompt = "" self.agent = None # Force re-init @@ -2108,10 +2159,10 @@ class HermesCLI: self.system_prompt = new_prompt self.agent = None # Force re-init if save_config_value("agent.system_prompt", new_prompt): - print(f"(^_^)b System prompt set (saved to config)") + print("(^_^)b System prompt set (saved to config)") else: - print(f"(^_^) System prompt set (session only)") - print(f" \"{new_prompt[:60]}{'...' if len(new_prompt) > 60 else ''}\"") + print("(^_^) System prompt set (session only)") + print(f' "{new_prompt[:60]}{"..." if len(new_prompt) > 60 else ""}"') else: # Show current prompt print() @@ -2142,15 +2193,15 @@ class HermesCLI: print(" /prompt clear - Remove custom prompt") print(" /personality - Use a predefined personality") print() - + def _handle_personality_command(self, cmd: str): """Handle the /personality command to set predefined personalities.""" parts = cmd.split(maxsplit=1) - + if len(parts) > 1: # Set personality personality_name = parts[1].strip().lower() - + if personality_name in self.personalities: self.system_prompt = self.personalities[personality_name] self.agent = None # Force re-init @@ -2158,7 +2209,7 @@ class HermesCLI: print(f"(^_^)b Personality set to '{personality_name}' (saved to config)") else: print(f"(^_^) Personality set to '{personality_name}' (session only)") - print(f" \"{self.system_prompt[:60]}{'...' if len(self.system_prompt) > 60 else ''}\"") + print(f' "{self.system_prompt[:60]}{"..." if len(self.system_prompt) > 60 else ""}"') else: print(f"(._.) Unknown personality: {personality_name}") print(f" Available: {', '.join(self.personalities.keys())}") @@ -2170,15 +2221,15 @@ class HermesCLI: print("+" + "-" * 50 + "+") print() for name, prompt in self.personalities.items(): - print(f" {name:<12} - \"{prompt}\"") + print(f' {name:<12} - "{prompt}"') print() print(" Usage: /personality ") print() - + def _handle_cron_command(self, cmd: str): """Handle the /cron command to manage scheduled tasks.""" parts = cmd.split(maxsplit=2) - + if len(parts) == 1: # /cron - show help and list print() @@ -2189,7 +2240,7 @@ class HermesCLI: print(" Commands:") print(" /cron - List scheduled jobs") print(" /cron list - List scheduled jobs") - print(' /cron add - Add a new job') + print(" /cron add - Add a new job") print(" /cron remove - Remove a job") print() print(" Schedule formats:") @@ -2197,7 +2248,7 @@ class HermesCLI: print(' "every 30m", "every 2h" - Recurring interval') print(' "0 9 * * *" - Cron expression') print() - + # Show current jobs jobs = list_jobs() if jobs: @@ -2211,12 +2262,13 @@ class HermesCLI: repeat_str = "forever" else: repeat_str = f"{completed}/{times}" - + print(f" {job['id'][:12]:<12} | {job['schedule_display']:<15} | {repeat_str:<8}") - prompt_preview = job['prompt'][:45] + "..." if len(job['prompt']) > 45 else job['prompt'] + prompt_preview = job["prompt"][:45] + "..." if len(job["prompt"]) > 45 else job["prompt"] print(f" {prompt_preview}") if job.get("next_run_at"): from datetime import datetime + next_run = datetime.fromisoformat(job["next_run_at"]) print(f" Next: {next_run.strftime('%Y-%m-%d %H:%M')}") print() @@ -2224,16 +2276,16 @@ class HermesCLI: print(" No scheduled jobs. Use '/cron add' to create one.") print() return - + subcommand = parts[1].lower() - + if subcommand == "list": # /cron list - just show jobs jobs = list_jobs() if not jobs: print("(._.) No scheduled jobs.") return - + print() print("Scheduled Jobs:") print("-" * 70) @@ -2241,7 +2293,7 @@ class HermesCLI: times = job["repeat"].get("times") completed = job["repeat"].get("completed", 0) repeat_str = "forever" if times is None else f"{completed}/{times}" - + print(f" ID: {job['id']}") print(f" Name: {job['name']}") print(f" Schedule: {job['schedule_display']} ({repeat_str})") @@ -2250,7 +2302,7 @@ class HermesCLI: if job.get("last_run_at"): print(f" Last run: {job['last_run_at']} ({job.get('last_status', '?')})") print() - + elif subcommand == "add": # /cron add if len(parts) < 3: @@ -2258,10 +2310,10 @@ class HermesCLI: print(" Example: /cron add 30m Remind me to take a break") print(' Example: /cron add "every 2h" Check server status at 192.168.1.1') return - + # Parse schedule and prompt rest = parts[2].strip() - + # Handle quoted schedule (e.g., "every 30m" or "0 9 * * *") if rest.startswith('"'): # Find closing quote @@ -2270,17 +2322,17 @@ class HermesCLI: print("(._.) Unmatched quote in schedule") return schedule = rest[1:close_quote] - prompt = rest[close_quote + 1:].strip() + prompt = rest[close_quote + 1 :].strip() else: # First word is schedule schedule_parts = rest.split(maxsplit=1) schedule = schedule_parts[0] prompt = schedule_parts[1] if len(schedule_parts) > 1 else "" - + if not prompt: print("(._.) Please provide a prompt for the job") return - + try: job = create_job(prompt=prompt, schedule=schedule) print(f"(^_^)b Created job: {job['id']}") @@ -2288,57 +2340,58 @@ class HermesCLI: print(f" Next run: {job['next_run_at']}") except Exception as e: print(f"(x_x) Failed to create job: {e}") - + elif subcommand == "remove" or subcommand == "rm" or subcommand == "delete": # /cron remove if len(parts) < 3: print("(._.) Usage: /cron remove ") return - + job_id = parts[2].strip() job = get_job(job_id) - + if not job: print(f"(._.) Job not found: {job_id}") return - + if remove_job(job_id): print(f"(^_^)b Removed job: {job['name']} ({job_id})") else: print(f"(x_x) Failed to remove job: {job_id}") - + else: print(f"(._.) Unknown cron command: {subcommand}") print(" Available: list, add, remove") - + def _handle_skills_command(self, cmd: str): """Handle /skills slash command — delegates to hermes_cli.skills_hub.""" from hermes_cli.skills_hub import handle_skills_slash + handle_skills_slash(cmd, ChatConsole()) def _show_gateway_status(self): """Show status of the gateway and connected messaging platforms.""" - from gateway.config import load_gateway_config, Platform - + from gateway.config import Platform, load_gateway_config + print() print("+" + "-" * 60 + "+") print("|" + " " * 15 + "(✿◠‿◠) Gateway Status" + " " * 17 + "|") print("+" + "-" * 60 + "+") print() - + try: config = load_gateway_config() connected = config.get_connected_platforms() - + print(" Messaging Platform Configuration:") print(" " + "-" * 55) - + platform_status = { Platform.TELEGRAM: ("Telegram", "TELEGRAM_BOT_TOKEN"), Platform.DISCORD: ("Discord", "DISCORD_BOT_TOKEN"), Platform.WHATSAPP: ("WhatsApp", "WHATSAPP_ENABLED"), } - + for platform, (name, env_var) in platform_status.items(): pconfig = config.platforms.get(platform) if pconfig and pconfig.enabled: @@ -2347,7 +2400,7 @@ class HermesCLI: print(f" ✓ {name:<12} Enabled{home_str}") else: print(f" ○ {name:<12} Not configured ({env_var})") - + print() print(" Session Reset Policy:") print(" " + "-" * 55) @@ -2355,14 +2408,14 @@ class HermesCLI: print(f" Mode: {policy.mode}") print(f" Daily reset at: {policy.at_hour}:00") print(f" Idle timeout: {policy.idle_minutes} minutes") - + print() print(" To start the gateway:") print(" python cli.py --gateway") print() print(" Configuration file: ~/.hermes/gateway.json") print() - + except Exception as e: print(f" Error loading gateway config: {e}") print() @@ -2372,21 +2425,21 @@ class HermesCLI: print(" DISCORD_BOT_TOKEN=your_token") print(" 2. Or create ~/.hermes/gateway.json") print() - + def process_command(self, command: str) -> bool: """ Process a slash command. - + Args: command: The command string (starting with /) - + Returns: bool: True to continue, False to exit """ # Lowercase only for dispatch matching; preserve original case for arguments cmd_lower = command.lower().strip() cmd_original = command.strip() - + if cmd_lower in ("/quit", "/exit", "/q"): return False elif cmd_lower == "/help": @@ -2430,7 +2483,7 @@ class HermesCLI: tools = get_tool_definitions(enabled_toolsets=self.enabled_toolsets, quiet_mode=True) cwd = os.getenv("TERMINAL_CWD", os.getcwd()) ctx_len = None - if hasattr(self, 'agent') and self.agent and hasattr(self.agent, 'context_compressor'): + if hasattr(self, "agent") and self.agent and hasattr(self.agent, "context_compressor"): ctx_len = self.agent.context_compressor.context_length build_welcome_banner( console=cc, @@ -2456,6 +2509,7 @@ class HermesCLI: # Sanitize the title early so feedback matches what gets stored try: from hermes_state import SessionDB + new_title = SessionDB.sanitize_title(raw_title) except ValueError as e: _cprint(f" {e}") @@ -2493,7 +2547,7 @@ class HermesCLI: elif self._pending_title: _cprint(f" Session title (pending): {self._pending_title}") else: - _cprint(f" No title set. Usage: /title ") + _cprint(" No title set. Usage: /title ") else: _cprint(" Session database not available.") elif cmd_lower in ("/reset", "/new"): @@ -2502,11 +2556,10 @@ class HermesCLI: # Use original case so model names like "Anthropic/Claude-Opus-4" are preserved parts = cmd_original.split(maxsplit=1) if len(parts) > 1: - from hermes_cli.auth import resolve_provider from hermes_cli.models import ( + _PROVIDER_LABELS, parse_model_input, validate_requested_model, - _PROVIDER_LABELS, ) raw_input = parts[1].strip() @@ -2522,6 +2575,7 @@ class HermesCLI: if provider_changed: try: from hermes_cli.runtime_provider import resolve_runtime_provider + runtime = resolve_runtime_provider(requested=target_provider) api_key_for_probe = runtime.get("api_key", "") base_url_for_probe = runtime.get("base_url", "") @@ -2574,8 +2628,9 @@ class HermesCLI: print(f" Reason: {message}") print(" Note: Model will revert on restart. Use a verified model to save to config.") else: - from hermes_cli.models import curated_models_for_provider, normalize_provider, _PROVIDER_LABELS from hermes_cli.auth import resolve_provider as _resolve_provider + from hermes_cli.models import _PROVIDER_LABELS, curated_models_for_provider, normalize_provider + # Resolve "auto" to the actual provider using credential detection raw_provider = normalize_provider(self.provider) if raw_provider == "auto": @@ -2606,8 +2661,9 @@ class HermesCLI: print(" Example: /model openrouter:anthropic/claude-sonnet-4.5") print(" See /provider for available providers") elif cmd_lower == "/provider": - from hermes_cli.models import list_available_providers, normalize_provider, _PROVIDER_LABELS from hermes_cli.auth import resolve_provider as _resolve_provider + from hermes_cli.models import _PROVIDER_LABELS, list_available_providers, normalize_provider + # Resolve current provider raw_provider = normalize_provider(self.provider) if raw_provider == "auto": @@ -2641,7 +2697,7 @@ class HermesCLI: self._handle_personality_command(cmd_original) elif cmd_lower == "/retry": retry_msg = self.retry_last() - if retry_msg and hasattr(self, '_pending_input'): + if retry_msg and hasattr(self, "_pending_input"): # Re-queue the message so process_loop sends it to the agent self._pending_input.put(retry_msg) elif cmd_lower == "/undo": @@ -2670,21 +2726,21 @@ class HermesCLI: # Check for skill slash commands (/gif-search, /axolotl, etc.) base_cmd = cmd_lower.split()[0] if base_cmd in _skill_commands: - user_instruction = cmd_original[len(base_cmd):].strip() + user_instruction = cmd_original[len(base_cmd) :].strip() msg = build_skill_invocation_message(base_cmd, user_instruction) if msg: skill_name = _skill_commands[base_cmd]["name"] print(f"\n⚡ Loading skill: {skill_name}") - if hasattr(self, '_pending_input'): + if hasattr(self, "_pending_input"): self._pending_input.put(msg) else: self.console.print(f"[bold red]Failed to load skill for {base_cmd}[/]") else: self.console.print(f"[bold red]Unknown command: {cmd_lower}[/]") self.console.print("[dim #B8860B]Type /help for available commands[/]") - + return True - + def _toggle_verbose(self): """Cycle tool progress mode: off → new → all → verbose → off.""" cycle = ["off", "new", "all", "verbose"] @@ -2724,6 +2780,7 @@ class HermesCLI: original_count = len(self.conversation_history) try: from agent.model_metadata import estimate_messages_tokens_rough + approx_tokens = estimate_messages_tokens_rough(self.conversation_history) print(f"🗜️ Compressing {original_count} messages (~{approx_tokens:,} tokens)...") @@ -2767,7 +2824,7 @@ class HermesCLI: msg_count = len(self.conversation_history) - print(f" 📊 Session Token Usage") + print(" 📊 Session Token Usage") print(f" {'─' * 40}") print(f" Prompt tokens (input): {prompt:>10,}") print(f" Completion tokens (output): {completion:>9,}") @@ -2780,11 +2837,11 @@ class HermesCLI: if self.verbose: logging.getLogger().setLevel(logging.DEBUG) - for noisy in ('openai', 'openai._base_client', 'httpx', 'httpcore', 'asyncio', 'hpack', 'grpc', 'modal'): + for noisy in ("openai", "openai._base_client", "httpx", "httpcore", "asyncio", "hpack", "grpc", "modal"): logging.getLogger(noisy).setLevel(logging.WARNING) else: logging.getLogger().setLevel(logging.INFO) - for quiet_logger in ('tools', 'minisweagent', 'run_agent', 'trajectory_compressor', 'cron', 'hermes_cli'): + for quiet_logger in ("tools", "minisweagent", "run_agent", "trajectory_compressor", "cron", "hermes_cli"): logging.getLogger(quiet_logger).setLevel(logging.ERROR) def _show_insights(self, command: str = "/insights"): @@ -2809,8 +2866,8 @@ class HermesCLI: i += 1 try: - from hermes_state import SessionDB from agent.insights import InsightsEngine + from hermes_state import SessionDB db = SessionDB() engine = InsightsEngine(db) @@ -2827,7 +2884,7 @@ class HermesCLI: sees the updated tools on the next turn. """ try: - from tools.mcp_tool import shutdown_mcp_servers, discover_mcp_tools, _load_mcp_config, _servers, _lock + from tools.mcp_tool import _lock, _servers, discover_mcp_tools, shutdown_mcp_servers # Capture old server names with _lock: @@ -2863,14 +2920,14 @@ class HermesCLI: # Refresh the agent's tool list so the model can call new tools if self.agent is not None: from model_tools import get_tool_definitions + self.agent.tools = get_tool_definitions( - enabled_toolsets=self.agent.enabled_toolsets - if hasattr(self.agent, "enabled_toolsets") else None, + enabled_toolsets=self.agent.enabled_toolsets if hasattr(self.agent, "enabled_toolsets") else None, quiet_mode=True, ) - self.agent.valid_tool_names = { - tool["function"]["name"] for tool in self.agent.tools - } if self.agent.tools else set() + self.agent.valid_tool_names = ( + {tool["function"]["name"] for tool in self.agent.tools} if self.agent.tools else set() + ) # Inject a message at the END of conversation history so the # model knows tools changed. Appended after all existing @@ -2884,10 +2941,12 @@ class HermesCLI: change_parts.append(f"Reconnected servers: {', '.join(sorted(reconnected))}") tool_summary = f"{len(new_tools)} MCP tool(s) now available" if new_tools else "No MCP tools available" change_detail = ". ".join(change_parts) + ". " if change_parts else "" - self.conversation_history.append({ - "role": "user", - "content": f"[SYSTEM: MCP servers have been reloaded. {change_detail}{tool_summary}. The tool list for this conversation has been updated accordingly.]", - }) + self.conversation_history.append( + { + "role": "user", + "content": f"[SYSTEM: MCP servers have been reloaded. {change_detail}{tool_summary}. The tool list for this conversation has been updated accordingly.]", + } + ) # Persist session immediately so the session log reflects the # updated tools list (self.agent.tools was refreshed above). @@ -2961,7 +3020,7 @@ class HermesCLI: def _sudo_password_callback(self) -> str: """ Prompt for sudo password through the prompt_toolkit UI. - + Called from the agent thread when a sudo command is encountered. Uses the same clarify-style mechanism: sets UI state, waits on a queue for the user's response via the Enter key binding. @@ -3004,7 +3063,7 @@ class HermesCLI: def _approval_callback(self, command: str, description: str) -> str: """ Prompt for dangerous command approval through the prompt_toolkit UI. - + Called from the agent thread. Shows a selection UI similar to clarify with choices: once / session / always / deny. """ @@ -3041,22 +3100,23 @@ class HermesCLI: self._approval_state = None self._approval_deadline = 0 self._invalidate() - def chat(self, message, images: list = None) -> Optional[str]: + + def chat(self, message, images: list = None) -> str | None: """ Send a message to the agent and get a response. - + Handles streaming output, interrupt detection (user typing while agent is working), and re-queueing of interrupted messages. - + Uses a dedicated _interrupt_queue (separate from _pending_input) to avoid race conditions between the process_loop and interrupt monitoring. Messages typed while the agent is running go to _interrupt_queue; messages typed while idle go to _pending_input. - + Args: message: The user's message (str or multimodal content list) images: Optional list of Path objects for attached images - + Returns: The agent's response, or None on error """ @@ -3067,26 +3127,24 @@ class HermesCLI: # Initialize agent if needed if not self._init_agent(): return None - + # Pre-process images through the vision tool (Gemini Flash) so the # main model receives text descriptions instead of raw base64 image # content — works with any model, not just vision-capable ones. if images: - message = self._preprocess_images_with_vision( - message if isinstance(message, str) else "", images - ) + message = self._preprocess_images_with_vision(message if isinstance(message, str) else "", images) # Add user message to history self.conversation_history.append({"role": "user", "content": message}) - + w = shutil.get_terminal_size().columns _cprint(f"{_GOLD}{'─' * w}{_RST}") print(flush=True) - + try: # Run the conversation with interrupt monitoring result = None - + def run_agent(): nonlocal result result = self.agent.run_conversation( @@ -3094,11 +3152,11 @@ class HermesCLI: conversation_history=self.conversation_history[:-1], # Exclude the message we just added task_id=self.session_id, ) - + # Start agent in background thread agent_thread = threading.Thread(target=run_agent) agent_thread.start() - + # Monitor the dedicated interrupt queue while the agent runs. # _interrupt_queue is separate from _pending_input, so process_loop # and chat() never compete for the same queue. @@ -3107,7 +3165,7 @@ class HermesCLI: # so we skip interrupt processing to avoid stealing that input. interrupt_msg = None while agent_thread.is_alive(): - if hasattr(self, '_interrupt_queue'): + if hasattr(self, "_interrupt_queue"): try: interrupt_msg = self._interrupt_queue.get(timeout=0.1) if interrupt_msg: @@ -3116,7 +3174,7 @@ class HermesCLI: # But if it does (race condition), don't interrupt. if self._clarify_state or self._clarify_freetext: continue - print(f"\n⚡ New message detected, interrupting...") + print("\n⚡ New message detected, interrupting...") self.agent.interrupt(interrupt_msg) break except queue.Empty: @@ -3124,7 +3182,7 @@ class HermesCLI: else: # Fallback for non-interactive mode (e.g., single-query) agent_thread.join(0.1) - + agent_thread.join() # Ensure agent thread completes # Drain any remaining agent output still in the StdoutProxy @@ -3132,20 +3190,23 @@ class HermesCLI: # The flush pushes data into the renderer queue; the short # sleep lets the renderer actually paint it before we draw. import time as _time + sys.stdout.flush() _time.sleep(0.15) # Update history with full conversation - self.conversation_history = result.get("messages", self.conversation_history) if result else self.conversation_history - + self.conversation_history = ( + result.get("messages", self.conversation_history) if result else self.conversation_history + ) + # Get the final response response = result.get("final_response", "") if result else "" - + # Handle failed results (e.g., non-retryable errors like invalid model) if result and result.get("failed") and not response: error_detail = result.get("error", "Unknown error") response = f"Error: {error_detail}" - + # Handle interrupt - check if we were interrupted pending_message = None if result and result.get("interrupted"): @@ -3153,7 +3214,7 @@ class HermesCLI: # Add indicator that we were interrupted if response and pending_message: response = response + "\n\n---\n_[Interrupted - processing new message]_" - + if response: w = shutil.get_terminal_size().columns label = " ⚕ Hermes " @@ -3164,16 +3225,16 @@ class HermesCLI: # Render box + response as a single _cprint call so # nothing can interleave between the box borders. _cprint(f"\n{top}\n{response}\n\n{bot}") - + # Play terminal bell when agent finishes (if enabled). # Works over SSH — the bell propagates to the user's terminal. if self.bell_on_complete: sys.stdout.write("\a") sys.stdout.flush() - + # Combine all interrupt messages (user may have typed multiple while waiting) # and re-queue as one prompt for process_loop - if pending_message and hasattr(self, '_pending_input'): + if pending_message and hasattr(self, "_pending_input"): all_parts = [pending_message] while not self._interrupt_queue.empty(): try: @@ -3185,13 +3246,13 @@ class HermesCLI: combined = "\n".join(all_parts) print(f"\n📨 Queued: '{combined[:50]}{'...' if len(combined) > 50 else ''}'") self._pending_input.put(combined) - + return response - + except Exception as e: print(f"Error: {e}") return None - + def _print_exit_summary(self): """Print session resume info on exit, similar to Claude Code.""" print() @@ -3208,8 +3269,8 @@ class HermesCLI: duration_str = f"{minutes}m {seconds}s" else: duration_str = f"{seconds}s" - - print(f"Resume this session with:") + + print("Resume this session with:") print(f" hermes --resume {self.session_id}") print() print(f"Session: {self.session_id}") @@ -3230,27 +3291,27 @@ class HermesCLI: self.console.print("[#FFF8DC]Welcome to Hermes Agent! Type your message or /help for commands.[/]") self.console.print() - + # State for async operation self._agent_running = False - self._pending_input = queue.Queue() # For normal input (commands + new queries) - self._interrupt_queue = queue.Queue() # For messages typed while agent is running + self._pending_input = queue.Queue() # For normal input (commands + new queries) + self._interrupt_queue = queue.Queue() # For messages typed while agent is running self._should_exit = False self._last_ctrl_c_time = 0 # Track double Ctrl+C for force exit # Clarify tool state: interactive question/answer with the user. # When the agent calls the clarify tool, _clarify_state is set and # the prompt_toolkit UI switches to a selection mode. - self._clarify_state = None # dict with question, choices, selected, response_queue + self._clarify_state = None # dict with question, choices, selected, response_queue self._clarify_freetext = False # True when user chose "Other" and is typing - self._clarify_deadline = 0 # monotonic timestamp when the clarify times out + self._clarify_deadline = 0 # monotonic timestamp when the clarify times out # Sudo password prompt state (similar mechanism to clarify) - self._sudo_state = None # dict with response_queue when active + self._sudo_state = None # dict with response_queue when active self._sudo_deadline = 0 # Dangerous command approval state (similar mechanism to clarify) - self._approval_state = None # dict with command, description, choices, selected, response_queue + self._approval_state = None # dict with command, description, choices, selected, response_queue self._approval_deadline = 0 # Clipboard image attachments (paste images into the CLI) @@ -3260,14 +3321,14 @@ class HermesCLI: # Register callbacks so terminal_tool prompts route through our UI set_sudo_password_callback(self._sudo_password_callback) set_approval_callback(self._approval_callback) - + # Key bindings for the input area kb = KeyBindings() - - @kb.add('enter') + + @kb.add("enter") def handle_enter(event): """Handle Enter key - submit input. - + Routes to the correct queue based on active UI state: - Sudo password prompt: password goes to sudo response queue - Approval selection: selected choice goes to approval response queue @@ -3339,27 +3400,27 @@ class HermesCLI: else: self._pending_input.put(payload) event.app.current_buffer.reset(append_to_history=True) - - @kb.add('escape', 'enter') + + @kb.add("escape", "enter") def handle_alt_enter(event): """Alt+Enter inserts a newline for multi-line input.""" - event.current_buffer.insert_text('\n') + event.current_buffer.insert_text("\n") - @kb.add('c-j') + @kb.add("c-j") def handle_ctrl_enter(event): """Ctrl+Enter (c-j) inserts a newline. Most terminals send c-j for Ctrl+Enter.""" - event.current_buffer.insert_text('\n') + event.current_buffer.insert_text("\n") # --- Clarify tool: arrow-key navigation for multiple-choice questions --- - @kb.add('up', filter=Condition(lambda: bool(self._clarify_state) and not self._clarify_freetext)) + @kb.add("up", filter=Condition(lambda: bool(self._clarify_state) and not self._clarify_freetext)) def clarify_up(event): """Move selection up in clarify choices.""" if self._clarify_state: self._clarify_state["selected"] = max(0, self._clarify_state["selected"] - 1) event.app.invalidate() - @kb.add('down', filter=Condition(lambda: bool(self._clarify_state) and not self._clarify_freetext)) + @kb.add("down", filter=Condition(lambda: bool(self._clarify_state) and not self._clarify_freetext)) def clarify_down(event): """Move selection down in clarify choices.""" if self._clarify_state: @@ -3370,13 +3431,13 @@ class HermesCLI: # --- Dangerous command approval: arrow-key navigation --- - @kb.add('up', filter=Condition(lambda: bool(self._approval_state))) + @kb.add("up", filter=Condition(lambda: bool(self._approval_state))) def approval_up(event): if self._approval_state: self._approval_state["selected"] = max(0, self._approval_state["selected"] - 1) event.app.invalidate() - @kb.add('down', filter=Condition(lambda: bool(self._approval_state))) + @kb.add("down", filter=Condition(lambda: bool(self._approval_state))) def approval_down(event): if self._approval_state: max_idx = len(self._approval_state["choices"]) - 1 @@ -3387,30 +3448,29 @@ class HermesCLI: # The TextArea is multiline, so by default up/down only move the cursor. # Buffer.auto_up/auto_down handle both: cursor movement when multi-line, # history browsing when on the first/last line (or single-line input). - _normal_input = Condition( - lambda: not self._clarify_state and not self._approval_state and not self._sudo_state - ) + _normal_input = Condition(lambda: not self._clarify_state and not self._approval_state and not self._sudo_state) - @kb.add('up', filter=_normal_input) + @kb.add("up", filter=_normal_input) def history_up(event): """Up arrow: browse history when on first line, else move cursor up.""" event.app.current_buffer.auto_up(count=event.arg) - @kb.add('down', filter=_normal_input) + @kb.add("down", filter=_normal_input) def history_down(event): """Down arrow: browse history when on last line, else move cursor down.""" event.app.current_buffer.auto_down(count=event.arg) - @kb.add('c-c') + @kb.add("c-c") def handle_ctrl_c(event): """Handle Ctrl+C - cancel interactive prompts, interrupt agent, or exit. - + Priority: 1. Cancel active sudo/approval/clarify prompt 2. Interrupt the running agent (first press) 3. Force exit (second press within 2s, or when idle) """ import time as _time + now = _time.time() # Cancel sudo prompt @@ -3430,9 +3490,7 @@ class HermesCLI: # Cancel clarify prompt if self._clarify_state: - self._clarify_state["response_queue"].put( - "The user cancelled. Use your best judgement to proceed." - ) + self._clarify_state["response_queue"].put("The user cancelled. Use your best judgement to proceed.") self._clarify_state = None self._clarify_freetext = False event.app.current_buffer.reset() @@ -3445,7 +3503,7 @@ class HermesCLI: self._should_exit = True event.app.exit() return - + self._last_ctrl_c_time = now print("\n⚡ Interrupting agent... (press Ctrl+C again to force exit)") self.agent.interrupt() @@ -3459,8 +3517,8 @@ class HermesCLI: else: self._should_exit = True event.app.exit() - - @kb.add('c-d') + + @kb.add("c-d") def handle_ctrl_d(event): """Handle Ctrl+D - exit.""" self._should_exit = True @@ -3482,7 +3540,7 @@ class HermesCLI: if pasted_text: event.current_buffer.insert_text(pasted_text) - @kb.add('c-v') + @kb.add("c-v") def handle_ctrl_v(event): """Fallback image paste for terminals without bracketed paste. @@ -3496,7 +3554,7 @@ class HermesCLI: if self._try_attach_clipboard_image(): event.app.invalidate() - @kb.add('escape', 'v') + @kb.add("escape", "v") def handle_alt_v(event): """Alt+V — paste image from clipboard. @@ -3518,22 +3576,22 @@ class HermesCLI: def get_prompt(): if cli_ref._sudo_state: - return [('class:sudo-prompt', '🔐 ❯ ')] + return [("class:sudo-prompt", "🔐 ❯ ")] if cli_ref._approval_state: - return [('class:prompt-working', '⚠ ❯ ')] + return [("class:prompt-working", "⚠ ❯ ")] if cli_ref._clarify_freetext: - return [('class:clarify-selected', '✎ ❯ ')] + return [("class:clarify-selected", "✎ ❯ ")] if cli_ref._clarify_state: - return [('class:prompt-working', '? ❯ ')] + return [("class:prompt-working", "? ❯ ")] if cli_ref._agent_running: - return [('class:prompt-working', '⚕ ❯ ')] - return [('class:prompt', '❯ ')] + return [("class:prompt-working", "⚕ ❯ ")] + return [("class:prompt", "❯ ")] # Create the input area with multiline (shift+enter), autocomplete, and paste handling input_area = TextArea( height=Dimension(min=1, max=8, preferred=1), prompt=get_prompt, - style='class:input-area', + style="class:input-area", multiline=True, wrap_lines=True, history=FileHistory(str(self._history_file)), @@ -3570,12 +3628,12 @@ class HermesCLI: def _on_text_changed(buf): """Detect large pastes and collapse them to a file reference.""" text = buf.text - line_count = text.count('\n') + line_count = text.count("\n") chars_added = len(text) - _prev_text_len[0] _prev_text_len[0] = len(text) # Heuristic: a real paste adds many characters at once (not just a # single newline from Alt+Enter) AND the result has 5+ lines. - if line_count >= 5 and chars_added > 1 and not text.startswith('/'): + if line_count >= 5 and chars_added > 1 and not text.startswith("/"): _paste_counter[0] += 1 # Save to temp file paste_dir = Path(os.path.expanduser("~/.hermes/pastes")) @@ -3600,6 +3658,7 @@ class HermesCLI: class _PlaceholderProcessor(Processor): """Render grayed-out placeholder text inside the input when empty.""" + def __init__(self, get_text): self._get_text = get_text @@ -3608,7 +3667,7 @@ class HermesCLI: text = self._get_text() if text: # Append after existing fragments (preserves the ❯ prompt) - return Transformation(fragments=ti.fragments + [('class:placeholder', text)]) + return Transformation(fragments=ti.fragments + [("class:placeholder", text)]) return Transformation(fragments=ti.fragments) def _get_placeholder(): @@ -3633,28 +3692,28 @@ class HermesCLI: if cli_ref._sudo_state: remaining = max(0, int(cli_ref._sudo_deadline - _time.monotonic())) return [ - ('class:hint', ' password hidden · Enter to skip'), - ('class:clarify-countdown', f' ({remaining}s)'), + ("class:hint", " password hidden · Enter to skip"), + ("class:clarify-countdown", f" ({remaining}s)"), ] if cli_ref._approval_state: remaining = max(0, int(cli_ref._approval_deadline - _time.monotonic())) return [ - ('class:hint', ' ↑/↓ to select, Enter to confirm'), - ('class:clarify-countdown', f' ({remaining}s)'), + ("class:hint", " ↑/↓ to select, Enter to confirm"), + ("class:clarify-countdown", f" ({remaining}s)"), ] if cli_ref._clarify_state: remaining = max(0, int(cli_ref._clarify_deadline - _time.monotonic())) - countdown = f' ({remaining}s)' if cli_ref._clarify_deadline else '' + countdown = f" ({remaining}s)" if cli_ref._clarify_deadline else "" if cli_ref._clarify_freetext: return [ - ('class:hint', ' type your answer and press Enter'), - ('class:clarify-countdown', countdown), + ("class:hint", " type your answer and press Enter"), + ("class:clarify-countdown", countdown), ] return [ - ('class:hint', ' ↑/↓ to select, Enter to confirm'), - ('class:clarify-countdown', countdown), + ("class:hint", " ↑/↓ to select, Enter to confirm"), + ("class:clarify-countdown", countdown), ] return [] @@ -3685,40 +3744,40 @@ class HermesCLI: lines = [] # Box top border - lines.append(('class:clarify-border', '╭─ ')) - lines.append(('class:clarify-title', 'Hermes needs your input')) - lines.append(('class:clarify-border', ' ─────────────────────────────╮\n')) - lines.append(('class:clarify-border', '│\n')) + lines.append(("class:clarify-border", "╭─ ")) + lines.append(("class:clarify-title", "Hermes needs your input")) + lines.append(("class:clarify-border", " ─────────────────────────────╮\n")) + lines.append(("class:clarify-border", "│\n")) # Question text - lines.append(('class:clarify-border', '│ ')) - lines.append(('class:clarify-question', question)) - lines.append(('', '\n')) - lines.append(('class:clarify-border', '│\n')) + lines.append(("class:clarify-border", "│ ")) + lines.append(("class:clarify-question", question)) + lines.append(("", "\n")) + lines.append(("class:clarify-border", "│\n")) if choices: # Multiple-choice mode: show selectable options for i, choice in enumerate(choices): - lines.append(('class:clarify-border', '│ ')) + lines.append(("class:clarify-border", "│ ")) if i == selected and not cli_ref._clarify_freetext: - lines.append(('class:clarify-selected', f'❯ {choice}')) + lines.append(("class:clarify-selected", f"❯ {choice}")) else: - lines.append(('class:clarify-choice', f' {choice}')) - lines.append(('', '\n')) + lines.append(("class:clarify-choice", f" {choice}")) + lines.append(("", "\n")) # "Other" option (5th line, only shown when choices exist) other_idx = len(choices) - lines.append(('class:clarify-border', '│ ')) + lines.append(("class:clarify-border", "│ ")) if selected == other_idx and not cli_ref._clarify_freetext: - lines.append(('class:clarify-selected', '❯ Other (type your answer)')) + lines.append(("class:clarify-selected", "❯ Other (type your answer)")) elif cli_ref._clarify_freetext: - lines.append(('class:clarify-active-other', '❯ Other (type below)')) + lines.append(("class:clarify-active-other", "❯ Other (type below)")) else: - lines.append(('class:clarify-choice', ' Other (type your answer)')) - lines.append(('', '\n')) + lines.append(("class:clarify-choice", " Other (type your answer)")) + lines.append(("", "\n")) - lines.append(('class:clarify-border', '│\n')) - lines.append(('class:clarify-border', '╰──────────────────────────────────────────────────╯\n')) + lines.append(("class:clarify-border", "│\n")) + lines.append(("class:clarify-border", "╰──────────────────────────────────────────────────╯\n")) return lines clarify_widget = ConditionalContainer( @@ -3736,15 +3795,15 @@ class HermesCLI: if not state: return [] lines = [] - lines.append(('class:sudo-border', '╭─ ')) - lines.append(('class:sudo-title', '🔐 Sudo Password Required')) - lines.append(('class:sudo-border', ' ──────────────────────────╮\n')) - lines.append(('class:sudo-border', '│\n')) - lines.append(('class:sudo-border', '│ ')) - lines.append(('class:sudo-text', 'Enter password below (hidden), or press Enter to skip')) - lines.append(('', '\n')) - lines.append(('class:sudo-border', '│\n')) - lines.append(('class:sudo-border', '╰──────────────────────────────────────────────────╯\n')) + lines.append(("class:sudo-border", "╭─ ")) + lines.append(("class:sudo-title", "🔐 Sudo Password Required")) + lines.append(("class:sudo-border", " ──────────────────────────╮\n")) + lines.append(("class:sudo-border", "│\n")) + lines.append(("class:sudo-border", "│ ")) + lines.append(("class:sudo-text", "Enter password below (hidden), or press Enter to skip")) + lines.append(("", "\n")) + lines.append(("class:sudo-border", "│\n")) + lines.append(("class:sudo-border", "╰──────────────────────────────────────────────────╯\n")) return lines sudo_widget = ConditionalContainer( @@ -3766,7 +3825,7 @@ class HermesCLI: choices = state["choices"] selected = state.get("selected", 0) - cmd_display = command[:70] + '...' if len(command) > 70 else command + cmd_display = command[:70] + "..." if len(command) > 70 else command choice_labels = { "once": "Allow once", "session": "Allow for this session", @@ -3775,27 +3834,27 @@ class HermesCLI: } lines = [] - lines.append(('class:approval-border', '╭─ ')) - lines.append(('class:approval-title', '⚠️ Dangerous Command')) - lines.append(('class:approval-border', ' ───────────────────────────────╮\n')) - lines.append(('class:approval-border', '│\n')) - lines.append(('class:approval-border', '│ ')) - lines.append(('class:approval-desc', description)) - lines.append(('', '\n')) - lines.append(('class:approval-border', '│ ')) - lines.append(('class:approval-cmd', cmd_display)) - lines.append(('', '\n')) - lines.append(('class:approval-border', '│\n')) + lines.append(("class:approval-border", "╭─ ")) + lines.append(("class:approval-title", "⚠️ Dangerous Command")) + lines.append(("class:approval-border", " ───────────────────────────────╮\n")) + lines.append(("class:approval-border", "│\n")) + lines.append(("class:approval-border", "│ ")) + lines.append(("class:approval-desc", description)) + lines.append(("", "\n")) + lines.append(("class:approval-border", "│ ")) + lines.append(("class:approval-cmd", cmd_display)) + lines.append(("", "\n")) + lines.append(("class:approval-border", "│\n")) for i, choice in enumerate(choices): - lines.append(('class:approval-border', '│ ')) + lines.append(("class:approval-border", "│ ")) label = choice_labels.get(choice, choice) if i == selected: - lines.append(('class:approval-selected', f'❯ {label}')) + lines.append(("class:approval-selected", f"❯ {label}")) else: - lines.append(('class:approval-choice', f' {label}')) - lines.append(('', '\n')) - lines.append(('class:approval-border', '│\n')) - lines.append(('class:approval-border', '╰──────────────────────────────────────────────────────╯\n')) + lines.append(("class:approval-choice", f" {label}")) + lines.append(("", "\n")) + lines.append(("class:approval-border", "│\n")) + lines.append(("class:approval-border", "╰──────────────────────────────────────────────────────╯\n")) return lines approval_widget = ConditionalContainer( @@ -3811,14 +3870,14 @@ class HermesCLI: # Using char='─' instead of hardcoded repetition so the rule # always spans the full terminal width on any screen size. input_rule_top = Window( - char='─', + char="─", height=1, - style='class:input-rule', + style="class:input-rule", ) input_rule_bot = Window( - char='─', + char="─", height=1, - style='class:input-rule', + style="class:input-rule", ) # Image attachment indicator — shows badges like [📎 Image #1] above input @@ -3828,10 +3887,7 @@ class HermesCLI: if not cli_ref._attached_images: return [] base = cli_ref._image_counter - len(cli_ref._attached_images) + 1 - badges = " ".join( - f"[📎 Image #{base + i}]" - for i in range(len(cli_ref._attached_images)) - ) + badges = " ".join(f"[📎 Image #{base + i}]" for i in range(len(cli_ref._attached_images))) return [("class:image-badge", f" {badges} ")] image_bar = Window( @@ -3843,58 +3899,62 @@ class HermesCLI: # The sudo, approval, and clarify widgets appear above the input when # the corresponding interactive prompt is active. layout = Layout( - HSplit([ - Window(height=0), - sudo_widget, - approval_widget, - clarify_widget, - spacer, - input_rule_top, - image_bar, - input_area, - input_rule_bot, - CompletionsMenu(max_height=12, scroll_offset=1), - ]) + HSplit( + [ + Window(height=0), + sudo_widget, + approval_widget, + clarify_widget, + spacer, + input_rule_top, + image_bar, + input_area, + input_rule_bot, + CompletionsMenu(max_height=12, scroll_offset=1), + ] + ) ) - + # Style for the application - style = PTStyle.from_dict({ - 'input-area': '#FFF8DC', - 'placeholder': '#555555 italic', - 'prompt': '#FFF8DC', - 'prompt-working': '#888888 italic', - 'hint': '#555555 italic', - # Bronze horizontal rules around the input area - 'input-rule': '#CD7F32', - # Clipboard image attachment badges - 'image-badge': '#87CEEB bold', - 'completion-menu': 'bg:#1a1a2e #FFF8DC', - 'completion-menu.completion': 'bg:#1a1a2e #FFF8DC', - 'completion-menu.completion.current': 'bg:#333355 #FFD700', - 'completion-menu.meta.completion': 'bg:#1a1a2e #888888', - 'completion-menu.meta.completion.current': 'bg:#333355 #FFBF00', - # Clarify question panel - 'clarify-border': '#CD7F32', - 'clarify-title': '#FFD700 bold', - 'clarify-question': '#FFF8DC bold', - 'clarify-choice': '#AAAAAA', - 'clarify-selected': '#FFD700 bold', - 'clarify-active-other': '#FFD700 italic', - 'clarify-countdown': '#CD7F32', - # Sudo password panel - 'sudo-prompt': '#FF6B6B bold', - 'sudo-border': '#CD7F32', - 'sudo-title': '#FF6B6B bold', - 'sudo-text': '#FFF8DC', - # Dangerous command approval panel - 'approval-border': '#CD7F32', - 'approval-title': '#FF8C00 bold', - 'approval-desc': '#FFF8DC bold', - 'approval-cmd': '#AAAAAA italic', - 'approval-choice': '#AAAAAA', - 'approval-selected': '#FFD700 bold', - }) - + style = PTStyle.from_dict( + { + "input-area": "#FFF8DC", + "placeholder": "#555555 italic", + "prompt": "#FFF8DC", + "prompt-working": "#888888 italic", + "hint": "#555555 italic", + # Bronze horizontal rules around the input area + "input-rule": "#CD7F32", + # Clipboard image attachment badges + "image-badge": "#87CEEB bold", + "completion-menu": "bg:#1a1a2e #FFF8DC", + "completion-menu.completion": "bg:#1a1a2e #FFF8DC", + "completion-menu.completion.current": "bg:#333355 #FFD700", + "completion-menu.meta.completion": "bg:#1a1a2e #888888", + "completion-menu.meta.completion.current": "bg:#333355 #FFBF00", + # Clarify question panel + "clarify-border": "#CD7F32", + "clarify-title": "#FFD700 bold", + "clarify-question": "#FFF8DC bold", + "clarify-choice": "#AAAAAA", + "clarify-selected": "#FFD700 bold", + "clarify-active-other": "#FFD700 italic", + "clarify-countdown": "#CD7F32", + # Sudo password panel + "sudo-prompt": "#FF6B6B bold", + "sudo-border": "#CD7F32", + "sudo-title": "#FF6B6B bold", + "sudo-text": "#FFF8DC", + # Dangerous command approval panel + "approval-border": "#CD7F32", + "approval-title": "#FF8C00 bold", + "approval-desc": "#FFF8DC bold", + "approval-cmd": "#AAAAAA italic", + "approval-choice": "#AAAAAA", + "approval-selected": "#FFD700 bold", + } + ) + # Create the application app = Application( layout=layout, @@ -3904,7 +3964,7 @@ class HermesCLI: mouse_support=False, ) self._app = app # Store reference for clarify_callback - + # Background thread to process inputs and run agent def process_loop(): while not self._should_exit: @@ -3914,7 +3974,7 @@ class HermesCLI: user_input = self._pending_input.get(timeout=0.1) except queue.Empty: continue - + if not user_input: continue @@ -3922,7 +3982,7 @@ class HermesCLI: submit_images = [] if isinstance(user_input, tuple): user_input, submit_images = user_input - + # Check for commands if isinstance(user_input, str) and user_input.startswith("/"): print(f"\n⚙️ {user_input}") @@ -3932,15 +3992,20 @@ class HermesCLI: if app.is_running: app.exit() continue - + # Expand paste references back to full content import re as _re - paste_match = _re.match(r'\[Pasted text #\d+: \d+ lines → (.+)\]', user_input) if isinstance(user_input, str) else None + + paste_match = ( + _re.match(r"\[Pasted text #\d+: \d+ lines → (.+)\]", user_input) + if isinstance(user_input, str) + else None + ) if paste_match: paste_path = Path(paste_match.group(1)) if paste_path.exists(): full_text = paste_path.read_text(encoding="utf-8") - line_count = full_text.count('\n') + 1 + line_count = full_text.count("\n") + 1 print() _cprint(f"{_GOLD}●{_RST} {_BOLD}[Pasted text: {line_count} lines]{_RST}") user_input = full_text @@ -3948,15 +4013,15 @@ class HermesCLI: print() _cprint(f"{_GOLD}●{_RST} {_BOLD}{user_input}{_RST}") else: - if '\n' in user_input: - first_line = user_input.split('\n')[0] - line_count = user_input.count('\n') + 1 + if "\n" in user_input: + first_line = user_input.split("\n")[0] + line_count = user_input.count("\n") + 1 print() _cprint(f"{_GOLD}●{_RST} {_BOLD}{first_line}{_RST} {_DIM}(+{line_count - 1} lines){_RST}") else: print() _cprint(f"{_GOLD}●{_RST} {_BOLD}{user_input}{_RST}") - + # Show image attachment count if submit_images: n = len(submit_images) @@ -3965,23 +4030,23 @@ class HermesCLI: # Regular chat - run agent self._agent_running = True app.invalidate() # Refresh status line - + try: self.chat(user_input, images=submit_images or None) finally: self._agent_running = False app.invalidate() # Refresh status line - + except Exception as e: print(f"Error: {e}") - + # Start processing thread process_thread = threading.Thread(target=process_loop, daemon=True) process_thread.start() - + # Register atexit cleanup so resources are freed even on unexpected exit atexit.register(_run_cleanup) - + # Run the application with patch_stdout for proper output handling try: with patch_stdout(): @@ -4000,7 +4065,7 @@ class HermesCLI: set_sudo_password_callback(None) set_approval_callback(None) # Close session in SQLite - if hasattr(self, '_session_db') and self._session_db and self.agent: + if hasattr(self, "_session_db") and self._session_db and self.agent: try: self._session_db.end_session(self.agent.session_id, "cli_close") except Exception as e: @@ -4013,6 +4078,7 @@ class HermesCLI: # Main Entry Point # ============================================================================ + def main( query: str = None, q: str = None, @@ -4033,7 +4099,7 @@ def main( ): """ Hermes Agent CLI - Interactive AI Assistant - + Args: query: Single query to execute (then exit). Alias: -q q: Shorthand for --query @@ -4050,7 +4116,7 @@ def main( resume: Resume a previous session by its ID (e.g., 20260225_143052_a1b2c3) worktree: Run in an isolated git worktree (for parallel agents). Alias: -w w: Shorthand for --worktree - + Examples: python cli.py # Start interactive mode python cli.py --toolsets web,terminal # Use specific toolsets @@ -4065,11 +4131,13 @@ def main( # Signal to terminal_tool that we're in interactive mode # This enables interactive sudo password prompts with timeout os.environ["HERMES_INTERACTIVE"] = "1" - + # Handle gateway mode (messaging + cron) if gateway: import asyncio + from gateway.run import start_gateway + print("Starting Hermes Gateway (messaging platforms)...") asyncio.run(start_gateway()) return @@ -4097,10 +4165,10 @@ def main( return else: wt_info = None - + # Handle query shorthand query = query or q - + # Parse toolsets - handle both string and tuple/list inputs # Default to hermes-cli toolset which includes cronjob management tools toolsets_list = None @@ -4122,7 +4190,7 @@ def main( toolsets_list = config_cli_toolsets else: toolsets_list = ["hermes-cli"] - + # Create CLI instance cli = HermesCLI( model=model, @@ -4146,21 +4214,21 @@ def main( f"The original repo is at {wt_info['repo_root']}.]" ) cli.system_prompt = (cli.system_prompt or "") + wt_note - + # Handle list commands (don't init agent for these) if list_tools: cli.show_banner() cli.show_tools() sys.exit(0) - + if list_toolsets: cli.show_banner() cli.show_toolsets() sys.exit(0) - + # Register cleanup for single-query mode (interactive mode registers in run()) atexit.register(_run_cleanup) - + # Handle single query mode if query: cli.show_banner() @@ -4168,7 +4236,7 @@ def main( cli.chat(query) cli._print_exit_summary() return - + # Run interactive mode cli.run() diff --git a/cron/__init__.py b/cron/__init__.py index 6a8f3ecbaf..a2ff62c52b 100644 --- a/cron/__init__.py +++ b/cron/__init__.py @@ -15,18 +15,18 @@ duplicate execution if multiple processes overlap. """ from cron.jobs import ( + JOBS_FILE, create_job, get_job, list_jobs, remove_job, update_job, - JOBS_FILE, ) from cron.scheduler import tick __all__ = [ "create_job", - "get_job", + "get_job", "list_jobs", "remove_job", "update_job", diff --git a/cron/jobs.py b/cron/jobs.py index c69ee7cf2f..0b70ca74d5 100644 --- a/cron/jobs.py +++ b/cron/jobs.py @@ -6,18 +6,19 @@ Output is saved to ~/.hermes/cron/output/{job_id}/{timestamp}.md """ import json -import tempfile import os import re +import tempfile import uuid from datetime import datetime, timedelta from pathlib import Path -from typing import Optional, Dict, List, Any +from typing import Any from hermes_time import now as _hermes_now try: from croniter import croniter + HAS_CRONITER = True except ImportError: HAS_CRONITER = False @@ -42,37 +43,38 @@ def ensure_dirs(): # Schedule Parsing # ============================================================================= + def parse_duration(s: str) -> int: """ Parse duration string into minutes. - + Examples: "30m" → 30 "2h" → 120 "1d" → 1440 """ s = s.strip().lower() - match = re.match(r'^(\d+)\s*(m|min|mins|minute|minutes|h|hr|hrs|hour|hours|d|day|days)$', s) + match = re.match(r"^(\d+)\s*(m|min|mins|minute|minutes|h|hr|hrs|hour|hours|d|day|days)$", s) if not match: raise ValueError(f"Invalid duration: '{s}'. Use format like '30m', '2h', or '1d'") - + value = int(match.group(1)) unit = match.group(2)[0] # First char: m, h, or d - - multipliers = {'m': 1, 'h': 60, 'd': 1440} + + multipliers = {"m": 1, "h": 60, "d": 1440} return value * multipliers[unit] -def parse_schedule(schedule: str) -> Dict[str, Any]: +def parse_schedule(schedule: str) -> dict[str, Any]: """ Parse schedule string into structured format. - + Returns dict with: - kind: "once" | "interval" | "cron" - For "once": "run_at" (ISO timestamp) - For "interval": "minutes" (int) - For "cron": "expr" (cron expression) - + Examples: "30m" → once in 30 minutes "2h" → once in 2 hours @@ -84,23 +86,17 @@ def parse_schedule(schedule: str) -> Dict[str, Any]: schedule = schedule.strip() original = schedule schedule_lower = schedule.lower() - + # "every X" pattern → recurring interval if schedule_lower.startswith("every "): duration_str = schedule[6:].strip() minutes = parse_duration(duration_str) - return { - "kind": "interval", - "minutes": minutes, - "display": f"every {minutes}m" - } - + return {"kind": "interval", "minutes": minutes, "display": f"every {minutes}m"} + # Check for cron expression (5 or 6 space-separated fields) # Cron fields: minute hour day month weekday [year] parts = schedule.split() - if len(parts) >= 5 and all( - re.match(r'^[\d\*\-,/]+$', p) for p in parts[:5] - ): + if len(parts) >= 5 and all(re.match(r"^[\d\*\-,/]+$", p) for p in parts[:5]): if not HAS_CRONITER: raise ValueError("Cron expressions require 'croniter' package. Install with: pip install croniter") # Validate cron expression @@ -108,37 +104,25 @@ def parse_schedule(schedule: str) -> Dict[str, Any]: croniter(schedule) except Exception as e: raise ValueError(f"Invalid cron expression '{schedule}': {e}") - return { - "kind": "cron", - "expr": schedule, - "display": schedule - } - + return {"kind": "cron", "expr": schedule, "display": schedule} + # ISO timestamp (contains T or looks like date) - if 'T' in schedule or re.match(r'^\d{4}-\d{2}-\d{2}', schedule): + if "T" in schedule or re.match(r"^\d{4}-\d{2}-\d{2}", schedule): try: # Parse and validate - dt = datetime.fromisoformat(schedule.replace('Z', '+00:00')) - return { - "kind": "once", - "run_at": dt.isoformat(), - "display": f"once at {dt.strftime('%Y-%m-%d %H:%M')}" - } + dt = datetime.fromisoformat(schedule.replace("Z", "+00:00")) + return {"kind": "once", "run_at": dt.isoformat(), "display": f"once at {dt.strftime('%Y-%m-%d %H:%M')}"} except ValueError as e: raise ValueError(f"Invalid timestamp '{schedule}': {e}") - + # Duration like "30m", "2h", "1d" → one-shot from now try: minutes = parse_duration(schedule) run_at = _hermes_now() + timedelta(minutes=minutes) - return { - "kind": "once", - "run_at": run_at.isoformat(), - "display": f"once in {original}" - } + return {"kind": "once", "run_at": run_at.isoformat(), "display": f"once in {original}"} except ValueError: pass - + raise ValueError( f"Invalid schedule '{original}'. Use:\n" f" - Duration: '30m', '2h', '1d' (one-shot)\n" @@ -161,7 +145,7 @@ def _ensure_aware(dt: datetime) -> datetime: return dt -def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None) -> Optional[str]: +def compute_next_run(schedule: dict[str, Any], last_run_at: str | None = None) -> str | None: """ Compute the next run time for a schedule. @@ -199,26 +183,27 @@ def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None # Job CRUD Operations # ============================================================================= -def load_jobs() -> List[Dict[str, Any]]: + +def load_jobs() -> list[dict[str, Any]]: """Load all jobs from storage.""" ensure_dirs() if not JOBS_FILE.exists(): return [] - + try: - with open(JOBS_FILE, 'r', encoding='utf-8') as f: + with open(JOBS_FILE, encoding="utf-8") as f: data = json.load(f) return data.get("jobs", []) - except (json.JSONDecodeError, IOError): + except (OSError, json.JSONDecodeError): return [] -def save_jobs(jobs: List[Dict[str, Any]]): +def save_jobs(jobs: list[dict[str, Any]]): """Save all jobs to storage.""" ensure_dirs() - fd, tmp_path = tempfile.mkstemp(dir=str(JOBS_FILE.parent), suffix='.tmp', prefix='.jobs_') + fd, tmp_path = tempfile.mkstemp(dir=str(JOBS_FILE.parent), suffix=".tmp", prefix=".jobs_") try: - with os.fdopen(fd, 'w', encoding='utf-8') as f: + with os.fdopen(fd, "w", encoding="utf-8") as f: json.dump({"jobs": jobs, "updated_at": _hermes_now().isoformat()}, f, indent=2) f.flush() os.fsync(f.fileno()) @@ -234,14 +219,14 @@ def save_jobs(jobs: List[Dict[str, Any]]): def create_job( prompt: str, schedule: str, - name: Optional[str] = None, - repeat: Optional[int] = None, - deliver: Optional[str] = None, - origin: Optional[Dict[str, Any]] = None -) -> Dict[str, Any]: + name: str | None = None, + repeat: int | None = None, + deliver: str | None = None, + origin: dict[str, Any] | None = None, +) -> dict[str, Any]: """ Create a new cron job. - + Args: prompt: The prompt to run (must be self-contained) schedule: Schedule string (see parse_schedule) @@ -249,23 +234,23 @@ def create_job( repeat: How many times to run (None = forever, 1 = once) deliver: Where to deliver output ("origin", "local", "telegram", etc.) origin: Source info where job was created (for "origin" delivery) - + Returns: The created job dict """ parsed_schedule = parse_schedule(schedule) - + # Auto-set repeat=1 for one-shot schedules if not specified if parsed_schedule["kind"] == "once" and repeat is None: repeat = 1 - + # Default delivery to origin if available, otherwise local if deliver is None: deliver = "origin" if origin else "local" - + job_id = uuid.uuid4().hex[:12] now = _hermes_now().isoformat() - + job = { "id": job_id, "name": name or prompt[:50].strip(), @@ -274,7 +259,7 @@ def create_job( "schedule_display": parsed_schedule.get("display", schedule), "repeat": { "times": repeat, # None = forever - "completed": 0 + "completed": 0, }, "enabled": True, "created_at": now, @@ -286,15 +271,15 @@ def create_job( "deliver": deliver, "origin": origin, # Tracks where job was created for "origin" delivery } - + jobs = load_jobs() jobs.append(job) save_jobs(jobs) - + return job -def get_job(job_id: str) -> Optional[Dict[str, Any]]: +def get_job(job_id: str) -> dict[str, Any] | None: """Get a job by ID.""" jobs = load_jobs() for job in jobs: @@ -303,7 +288,7 @@ def get_job(job_id: str) -> Optional[Dict[str, Any]]: return None -def list_jobs(include_disabled: bool = False) -> List[Dict[str, Any]]: +def list_jobs(include_disabled: bool = False) -> list[dict[str, Any]]: """List all jobs, optionally including disabled ones.""" jobs = load_jobs() if not include_disabled: @@ -311,7 +296,7 @@ def list_jobs(include_disabled: bool = False) -> List[Dict[str, Any]]: return jobs -def update_job(job_id: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]: +def update_job(job_id: str, updates: dict[str, Any]) -> dict[str, Any] | None: """Update a job by ID.""" jobs = load_jobs() for i, job in enumerate(jobs): @@ -333,10 +318,10 @@ def remove_job(job_id: str) -> bool: return False -def mark_job_run(job_id: str, success: bool, error: Optional[str] = None): +def mark_job_run(job_id: str, success: bool, error: str | None = None): """ Mark a job as having been run. - + Updates last_run_at, last_status, increments completed count, computes next_run_at, and auto-deletes if repeat limit reached. """ @@ -347,11 +332,11 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None): job["last_run_at"] = now job["last_status"] = "ok" if success else "error" job["last_error"] = error if not success else None - + # Increment completed count if job.get("repeat"): job["repeat"]["completed"] = job["repeat"].get("completed", 0) + 1 - + # Check if we've hit the repeat limit times = job["repeat"].get("times") completed = job["repeat"]["completed"] @@ -360,38 +345,38 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None): jobs.pop(i) save_jobs(jobs) return - + # Compute next run job["next_run_at"] = compute_next_run(job["schedule"], now) - + # If no next run (one-shot completed), disable if job["next_run_at"] is None: job["enabled"] = False - + save_jobs(jobs) return - + save_jobs(jobs) -def get_due_jobs() -> List[Dict[str, Any]]: +def get_due_jobs() -> list[dict[str, Any]]: """Get all jobs that are due to run now.""" now = _hermes_now() jobs = load_jobs() due = [] - + for job in jobs: if not job.get("enabled", True): continue - + next_run = job.get("next_run_at") if not next_run: continue - + next_run_dt = _ensure_aware(datetime.fromisoformat(next_run)) if next_run_dt <= now: due.append(job) - + return due @@ -400,11 +385,11 @@ def save_job_output(job_id: str, output: str): ensure_dirs() job_output_dir = OUTPUT_DIR / job_id job_output_dir.mkdir(parents=True, exist_ok=True) - + timestamp = _hermes_now().strftime("%Y-%m-%d_%H-%M-%S") output_file = job_output_dir / f"{timestamp}.md" - - with open(output_file, 'w', encoding='utf-8') as f: + + with open(output_file, "w", encoding="utf-8") as f: f.write(output) - + return output_file diff --git a/cron/scheduler.py b/cron/scheduler.py index 1f96d6443b..8b053934fd 100644 --- a/cron/scheduler.py +++ b/cron/scheduler.py @@ -23,9 +23,7 @@ except ImportError: import msvcrt except ImportError: msvcrt = None -from datetime import datetime from pathlib import Path -from typing import Optional from hermes_time import now as _hermes_now @@ -44,7 +42,7 @@ _LOCK_DIR = _hermes_home / "cron" _LOCK_FILE = _LOCK_DIR / ".tick.lock" -def _resolve_origin(job: dict) -> Optional[dict]: +def _resolve_origin(job: dict) -> dict | None: """Extract origin info from a job, returning {platform, chat_id, chat_name} or None.""" origin = job.get("origin") if not origin: @@ -87,11 +85,16 @@ def _deliver_result(job: dict, content: str) -> None: # Fall back to home channel chat_id = os.getenv(f"{platform_name.upper()}_HOME_CHANNEL", "") if not chat_id: - logger.warning("Job '%s' deliver=%s but no chat_id or home channel. Set via: hermes config set %s_HOME_CHANNEL ", job["id"], deliver, platform_name.upper()) + logger.warning( + "Job '%s' deliver=%s but no chat_id or home channel. Set via: hermes config set %s_HOME_CHANNEL ", + job["id"], + deliver, + platform_name.upper(), + ) return + from gateway.config import Platform, load_gateway_config from tools.send_message_tool import _send_to_platform - from gateway.config import load_gateway_config, Platform platform_map = { "telegram": Platform.TELEGRAM, @@ -123,6 +126,7 @@ def _deliver_result(job: dict, content: str) -> None: # asyncio.run() fails if there's already a running loop in this thread; # spin up a new thread to avoid that. import concurrent.futures + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, content)) result = future.result(timeout=30) @@ -137,25 +141,26 @@ def _deliver_result(job: dict, content: str) -> None: # Mirror the delivered content into the target's gateway session try: from gateway.mirror import mirror_to_session + mirror_to_session(platform_name, chat_id, content, source_label="cron") except Exception: pass -def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: +def run_job(job: dict) -> tuple[bool, str, str, str | None]: """ Execute a single cron job. - + Returns: Tuple of (success, full_output_doc, final_response, error_message) """ from run_agent import AIAgent - + job_id = job["id"] job_name = job["name"] prompt = job["prompt"] origin = _resolve_origin(job) - + logger.info("Running job '%s' (ID: %s)", job_name, job_id) logger.info("Prompt: %s", prompt[:100]) @@ -170,6 +175,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: # Re-read .env and config.yaml fresh every run so provider/key # changes take effect without a gateway restart. from dotenv import load_dotenv + try: load_dotenv(str(_hermes_home / ".env"), override=True, encoding="utf-8") except UnicodeDecodeError: @@ -181,6 +187,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: _cfg = {} try: import yaml + _cfg_path = str(_hermes_home / "config.yaml") if os.path.exists(_cfg_path): with open(_cfg_path) as _f: @@ -210,12 +217,13 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: prefill_file = os.getenv("HERMES_PREFILL_MESSAGES_FILE", "") or _cfg.get("prefill_messages_file", "") if prefill_file: import json as _json + pfpath = Path(prefill_file).expanduser() if not pfpath.is_absolute(): pfpath = _hermes_home / pfpath if pfpath.exists(): try: - with open(pfpath, "r", encoding="utf-8") as _pf: + with open(pfpath, encoding="utf-8") as _pf: prefill_messages = _json.load(_pf) if not isinstance(prefill_messages, list): prefill_messages = None @@ -229,9 +237,10 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: pr = _cfg.get("provider_routing", {}) from hermes_cli.runtime_provider import ( - resolve_runtime_provider, format_runtime_provider_error, + resolve_runtime_provider, ) + try: runtime = resolve_runtime_provider( requested=os.getenv("HERMES_INFERENCE_PROVIDER"), @@ -254,20 +263,20 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: providers_order=pr.get("order"), provider_sort=pr.get("sort"), quiet_mode=True, - session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}" + session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}", ) - + result = agent.run_conversation(prompt) - + final_response = result.get("final_response", "") if not final_response: final_response = "(No response generated)" - + output = f"""# Cron Job: {job_name} **Job ID:** {job_id} -**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')} -**Schedule:** {job.get('schedule_display', 'N/A')} +**Run Time:** {_hermes_now().strftime("%Y-%m-%d %H:%M:%S")} +**Schedule:** {job.get("schedule_display", "N/A")} ## Prompt @@ -277,19 +286,19 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: {final_response} """ - + logger.info("Job '%s' completed successfully", job_name) return True, output, final_response, None - + except Exception as e: error_msg = f"{type(e).__name__}: {str(e)}" logger.error("Job '%s' failed: %s", job_name, error_msg) - + output = f"""# Cron Job: {job_name} (FAILED) **Job ID:** {job_id} -**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')} -**Schedule:** {job.get('schedule_display', 'N/A')} +**Run Time:** {_hermes_now().strftime("%Y-%m-%d %H:%M:%S")} +**Schedule:** {job.get("schedule_display", "N/A")} ## Prompt @@ -314,13 +323,13 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: def tick(verbose: bool = True) -> int: """ Check and run all due jobs. - + Uses a file lock so only one tick runs at a time, even if the gateway's in-process ticker and a standalone daemon or manual tick overlap. - + Args: verbose: Whether to print status messages - + Returns: Number of jobs executed (0 if another tick is already running) """ @@ -334,7 +343,7 @@ def tick(verbose: bool = True) -> int: fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB) elif msvcrt: msvcrt.locking(lock_fd.fileno(), msvcrt.LK_NBLCK, 1) - except (OSError, IOError): + except OSError: logger.debug("Tick skipped — another instance holds the lock") if lock_fd is not None: lock_fd.close() @@ -344,11 +353,11 @@ def tick(verbose: bool = True) -> int: due_jobs = get_due_jobs() if verbose and not due_jobs: - logger.info("%s - No jobs due", _hermes_now().strftime('%H:%M:%S')) + logger.info("%s - No jobs due", _hermes_now().strftime("%H:%M:%S")) return 0 if verbose: - logger.info("%s - %s job(s) due", _hermes_now().strftime('%H:%M:%S'), len(due_jobs)) + logger.info("%s - %s job(s) due", _hermes_now().strftime("%H:%M:%S"), len(due_jobs)) executed = 0 for job in due_jobs: @@ -360,7 +369,9 @@ def tick(verbose: bool = True) -> int: logger.info("Output saved to: %s", output_file) # Deliver the final response to the origin/target chat - deliver_content = final_response if success else f"⚠️ Cron job '{job.get('name', job['id'])}' failed:\n{error}" + deliver_content = ( + final_response if success else f"⚠️ Cron job '{job.get('name', job['id'])}' failed:\n{error}" + ) if deliver_content: try: _deliver_result(job, deliver_content) @@ -371,7 +382,7 @@ def tick(verbose: bool = True) -> int: executed += 1 except Exception as e: - logger.error("Error processing job %s: %s", job['id'], e) + logger.error("Error processing job %s: %s", job["id"], e) mark_job_run(job["id"], False, str(e)) return executed @@ -381,7 +392,7 @@ def tick(verbose: bool = True) -> int: elif msvcrt: try: msvcrt.locking(lock_fd.fileno(), msvcrt.LK_UNLCK, 1) - except (OSError, IOError): + except OSError: pass lock_fd.close() diff --git a/gateway/__init__.py b/gateway/__init__.py index 8b6d988934..3b970ffa68 100644 --- a/gateway/__init__.py +++ b/gateway/__init__.py @@ -9,19 +9,18 @@ to various messaging platforms (Telegram, Discord, WhatsApp) with: - Platform-specific toolsets (different capabilities per platform) """ -from .config import GatewayConfig, PlatformConfig, HomeChannel, load_gateway_config +from .config import GatewayConfig, HomeChannel, PlatformConfig, SessionResetPolicy, load_gateway_config +from .delivery import DeliveryRouter, DeliveryTarget from .session import ( SessionContext, SessionStore, - SessionResetPolicy, build_session_context_prompt, ) -from .delivery import DeliveryRouter, DeliveryTarget __all__ = [ # Config "GatewayConfig", - "PlatformConfig", + "PlatformConfig", "HomeChannel", "load_gateway_config", # Session diff --git a/gateway/channel_directory.py b/gateway/channel_directory.py index 31406a7dec..58627b7275 100644 --- a/gateway/channel_directory.py +++ b/gateway/channel_directory.py @@ -10,7 +10,7 @@ import json import logging from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any logger = logging.getLogger(__name__) @@ -21,7 +21,8 @@ DIRECTORY_PATH = Path.home() / ".hermes" / "channel_directory.json" # Build / refresh # --------------------------------------------------------------------------- -def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]: + +def build_channel_directory(adapters: dict[Any, Any]) -> dict[str, Any]: """ Build a channel directory from connected platform adapters and session data. @@ -29,7 +30,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]: """ from gateway.config import Platform - platforms: Dict[str, List[Dict[str, str]]] = {} + platforms: dict[str, list[dict[str, str]]] = {} for platform, adapter in adapters.items(): try: @@ -60,7 +61,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]: return directory -def _build_discord(adapter) -> List[Dict[str, str]]: +def _build_discord(adapter) -> list[dict[str, str]]: """Enumerate all text channels the Discord bot can see.""" channels = [] client = getattr(adapter, "_client", None) @@ -74,12 +75,14 @@ def _build_discord(adapter) -> List[Dict[str, str]]: for guild in client.guilds: for ch in guild.text_channels: - channels.append({ - "id": str(ch.id), - "name": ch.name, - "guild": guild.name, - "type": "channel", - }) + channels.append( + { + "id": str(ch.id), + "name": ch.name, + "guild": guild.name, + "type": "channel", + } + ) # Also include DM-capable users we've interacted with is not # feasible via guild enumeration; those come from sessions. @@ -88,7 +91,7 @@ def _build_discord(adapter) -> List[Dict[str, str]]: return channels -def _build_slack(adapter) -> List[Dict[str, str]]: +def _build_slack(adapter) -> list[dict[str, str]]: """List Slack channels the bot has joined.""" channels = [] # Slack adapter may expose a web client @@ -97,7 +100,6 @@ def _build_slack(adapter) -> List[Dict[str, str]]: return _build_from_sessions("slack") try: - import asyncio from tools.send_message_tool import _send_slack # noqa: F401 # Use the Slack Web API directly if available except Exception: @@ -107,7 +109,7 @@ def _build_slack(adapter) -> List[Dict[str, str]]: return _build_from_sessions("slack") -def _build_from_sessions(platform_name: str) -> List[Dict[str, str]]: +def _build_from_sessions(platform_name: str) -> list[dict[str, str]]: """Pull known channels/contacts from sessions.json origin data.""" sessions_path = Path.home() / ".hermes" / "sessions" / "sessions.json" if not sessions_path.exists(): @@ -127,11 +129,13 @@ def _build_from_sessions(platform_name: str) -> List[Dict[str, str]]: if not chat_id or chat_id in seen_ids: continue seen_ids.add(chat_id) - entries.append({ - "id": str(chat_id), - "name": origin.get("chat_name") or origin.get("user_name") or str(chat_id), - "type": session.get("chat_type", "dm"), - }) + entries.append( + { + "id": str(chat_id), + "name": origin.get("chat_name") or origin.get("user_name") or str(chat_id), + "type": session.get("chat_type", "dm"), + } + ) except Exception as e: logger.debug("Channel directory: failed to read sessions for %s: %s", platform_name, e) @@ -142,7 +146,8 @@ def _build_from_sessions(platform_name: str) -> List[Dict[str, str]]: # Read / resolve # --------------------------------------------------------------------------- -def load_directory() -> Dict[str, Any]: + +def load_directory() -> dict[str, Any]: """Load the cached channel directory from disk.""" if not DIRECTORY_PATH.exists(): return {"updated_at": None, "platforms": {}} @@ -153,7 +158,7 @@ def load_directory() -> Dict[str, Any]: return {"updated_at": None, "platforms": {}} -def resolve_channel_name(platform_name: str, name: str) -> Optional[str]: +def resolve_channel_name(platform_name: str, name: str) -> str | None: """ Resolve a human-friendly channel name to a numeric ID. @@ -206,8 +211,8 @@ def format_directory_for_display() -> str: # Group Discord channels by guild if plat_name == "discord": - guilds: Dict[str, List] = {} - dms: List = [] + guilds: dict[str, list] = {} + dms: list = [] for ch in channels: guild = ch.get("guild") if guild: diff --git a/gateway/config.py b/gateway/config.py index 9a517f81b2..1d0668f5d9 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -8,19 +8,20 @@ Handles loading and validating configuration for: - Delivery preferences """ +import json import logging import os -import json -from pathlib import Path from dataclasses import dataclass, field -from typing import Dict, List, Optional, Any from enum import Enum +from pathlib import Path +from typing import Any logger = logging.getLogger(__name__) class Platform(Enum): """Supported messaging platforms.""" + LOCAL = "local" TELEGRAM = "telegram" DISCORD = "discord" @@ -34,23 +35,24 @@ class Platform(Enum): class HomeChannel: """ Default destination for a platform. - + When a cron job specifies deliver="telegram" without a specific chat ID, messages are sent to this home channel. """ + platform: Platform chat_id: str name: str # Human-readable name for display - - def to_dict(self) -> Dict[str, Any]: + + def to_dict(self) -> dict[str, Any]: return { "platform": self.platform.value, "chat_id": self.chat_id, "name": self.name, } - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "HomeChannel": + def from_dict(cls, data: dict[str, Any]) -> "HomeChannel": return cls( platform=Platform(data["platform"]), chat_id=str(data["chat_id"]), @@ -62,26 +64,27 @@ class HomeChannel: class SessionResetPolicy: """ Controls when sessions reset (lose context). - + Modes: - "daily": Reset at a specific hour each day - "idle": Reset after N minutes of inactivity - "both": Whichever triggers first (daily boundary OR idle timeout) - "none": Never auto-reset (context managed only by compression) """ + mode: str = "both" # "daily", "idle", "both", or "none" at_hour: int = 4 # Hour for daily reset (0-23, local time) idle_minutes: int = 1440 # Minutes of inactivity before reset (24 hours) - - def to_dict(self) -> Dict[str, Any]: + + def to_dict(self) -> dict[str, Any]: return { "mode": self.mode, "at_hour": self.at_hour, "idle_minutes": self.idle_minutes, } - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SessionResetPolicy": + def from_dict(cls, data: dict[str, Any]) -> "SessionResetPolicy": return cls( mode=data.get("mode", "both"), at_hour=data.get("at_hour", 4), @@ -92,15 +95,16 @@ class SessionResetPolicy: @dataclass class PlatformConfig: """Configuration for a single messaging platform.""" + enabled: bool = False - token: Optional[str] = None # Bot token (Telegram, Discord) - api_key: Optional[str] = None # API key if different from token - home_channel: Optional[HomeChannel] = None - + token: str | None = None # Bot token (Telegram, Discord) + api_key: str | None = None # API key if different from token + home_channel: HomeChannel | None = None + # Platform-specific settings - extra: Dict[str, Any] = field(default_factory=dict) - - def to_dict(self) -> Dict[str, Any]: + extra: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: result = { "enabled": self.enabled, "extra": self.extra, @@ -112,13 +116,13 @@ class PlatformConfig: if self.home_channel: result["home_channel"] = self.home_channel.to_dict() return result - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "PlatformConfig": + def from_dict(cls, data: dict[str, Any]) -> "PlatformConfig": home_channel = None if "home_channel" in data: home_channel = HomeChannel.from_dict(data["home_channel"]) - + return cls( enabled=data.get("enabled", False), token=data.get("token"), @@ -132,89 +136,80 @@ class PlatformConfig: class GatewayConfig: """ Main gateway configuration. - + Manages all platform connections, session policies, and delivery settings. """ + # Platform configurations - platforms: Dict[Platform, PlatformConfig] = field(default_factory=dict) - + platforms: dict[Platform, PlatformConfig] = field(default_factory=dict) + # Session reset policies by type default_reset_policy: SessionResetPolicy = field(default_factory=SessionResetPolicy) - reset_by_type: Dict[str, SessionResetPolicy] = field(default_factory=dict) - reset_by_platform: Dict[Platform, SessionResetPolicy] = field(default_factory=dict) - + reset_by_type: dict[str, SessionResetPolicy] = field(default_factory=dict) + reset_by_platform: dict[Platform, SessionResetPolicy] = field(default_factory=dict) + # Reset trigger commands - reset_triggers: List[str] = field(default_factory=lambda: ["/new", "/reset"]) - + reset_triggers: list[str] = field(default_factory=lambda: ["/new", "/reset"]) + # Storage paths sessions_dir: Path = field(default_factory=lambda: Path.home() / ".hermes" / "sessions") - + # Delivery settings always_log_local: bool = True # Always save cron outputs to local files - - def get_connected_platforms(self) -> List[Platform]: + + def get_connected_platforms(self) -> list[Platform]: """Return list of platforms that are enabled and configured.""" connected = [] for platform, config in self.platforms.items(): if not config.enabled: continue # Platforms that use token/api_key auth - if config.token or config.api_key: - connected.append(platform) - # WhatsApp uses enabled flag only (bridge handles auth) - elif platform == Platform.WHATSAPP: - connected.append(platform) - # Signal uses extra dict for config (http_url + account) - elif platform == Platform.SIGNAL and config.extra.get("http_url"): + if ( + config.token + or config.api_key + or platform == Platform.WHATSAPP + or platform == Platform.SIGNAL + and config.extra.get("http_url") + ): connected.append(platform) return connected - - def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]: + + def get_home_channel(self, platform: Platform) -> HomeChannel | None: """Get the home channel for a platform.""" config = self.platforms.get(platform) if config: return config.home_channel return None - - def get_reset_policy( - self, - platform: Optional[Platform] = None, - session_type: Optional[str] = None - ) -> SessionResetPolicy: + + def get_reset_policy(self, platform: Platform | None = None, session_type: str | None = None) -> SessionResetPolicy: """ Get the appropriate reset policy for a session. - + Priority: platform override > type override > default """ # Platform-specific override takes precedence if platform and platform in self.reset_by_platform: return self.reset_by_platform[platform] - + # Type-specific override (dm, group, thread) if session_type and session_type in self.reset_by_type: return self.reset_by_type[session_type] - + return self.default_reset_policy - - def to_dict(self) -> Dict[str, Any]: + + def to_dict(self) -> dict[str, Any]: return { - "platforms": { - p.value: c.to_dict() for p, c in self.platforms.items() - }, + "platforms": {p.value: c.to_dict() for p, c in self.platforms.items()}, "default_reset_policy": self.default_reset_policy.to_dict(), - "reset_by_type": { - k: v.to_dict() for k, v in self.reset_by_type.items() - }, - "reset_by_platform": { - p.value: v.to_dict() for p, v in self.reset_by_platform.items() - }, + "reset_by_type": {k: v.to_dict() for k, v in self.reset_by_type.items()}, + "reset_by_platform": {p.value: v.to_dict() for p, v in self.reset_by_platform.items()}, "reset_triggers": self.reset_triggers, "sessions_dir": str(self.sessions_dir), "always_log_local": self.always_log_local, } - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GatewayConfig": + def from_dict(cls, data: dict[str, Any]) -> "GatewayConfig": platforms = {} for platform_name, platform_data in data.get("platforms", {}).items(): try: @@ -222,11 +217,11 @@ class GatewayConfig: platforms[platform] = PlatformConfig.from_dict(platform_data) except ValueError: pass # Skip unknown platforms - + reset_by_type = {} for type_name, policy_data in data.get("reset_by_type", {}).items(): reset_by_type[type_name] = SessionResetPolicy.from_dict(policy_data) - + reset_by_platform = {} for platform_name, policy_data in data.get("reset_by_platform", {}).items(): try: @@ -234,15 +229,15 @@ class GatewayConfig: reset_by_platform[platform] = SessionResetPolicy.from_dict(policy_data) except ValueError: pass - + default_policy = SessionResetPolicy() if "default_reset_policy" in data: default_policy = SessionResetPolicy.from_dict(data["default_reset_policy"]) - + sessions_dir = Path.home() / ".hermes" / "sessions" if "sessions_dir" in data: sessions_dir = Path(data["sessions_dir"]) - + return cls( platforms=platforms, default_reset_policy=default_policy, @@ -257,7 +252,7 @@ class GatewayConfig: def load_gateway_config() -> GatewayConfig: """ Load gateway configuration from multiple sources. - + Priority (highest to lowest): 1. Environment variables 2. ~/.hermes/gateway.json @@ -265,22 +260,23 @@ def load_gateway_config() -> GatewayConfig: 4. Defaults """ config = GatewayConfig() - + # Try loading from ~/.hermes/gateway.json gateway_config_path = Path.home() / ".hermes" / "gateway.json" if gateway_config_path.exists(): try: - with open(gateway_config_path, "r") as f: + with open(gateway_config_path) as f: data = json.load(f) config = GatewayConfig.from_dict(data) except Exception as e: print(f"[gateway] Warning: Failed to load {gateway_config_path}: {e}") - + # Bridge session_reset from config.yaml (the user-facing config file) # into the gateway config. config.yaml takes precedence over gateway.json # for session reset policy since that's where hermes setup writes it. try: import yaml + config_yaml_path = Path.home() / ".hermes" / "config.yaml" if config_yaml_path.exists(): with open(config_yaml_path) as f: @@ -293,14 +289,12 @@ def load_gateway_config() -> GatewayConfig: # Override with environment variables _apply_env_overrides(config) - + # --- Validate loaded values --- policy = config.default_reset_policy if not (0 <= policy.at_hour <= 23): - logger.warning( - "Invalid at_hour=%s (must be 0-23). Using default 4.", policy.at_hour - ) + logger.warning("Invalid at_hour=%s (must be 0-23). Using default 4.", policy.at_hour) policy.at_hour = 4 if policy.idle_minutes is None or policy.idle_minutes <= 0: @@ -323,9 +317,9 @@ def load_gateway_config() -> GatewayConfig: env_name = _token_env_names.get(platform) if env_name and pconfig.token is not None and not pconfig.token.strip(): logger.warning( - "%s is enabled but %s is empty. " - "The adapter will likely fail to connect.", - platform.value, env_name, + "%s is enabled but %s is empty. The adapter will likely fail to connect.", + platform.value, + env_name, ) return config @@ -333,7 +327,7 @@ def load_gateway_config() -> GatewayConfig: def _apply_env_overrides(config: GatewayConfig) -> None: """Apply environment variable overrides to config.""" - + # Telegram telegram_token = os.getenv("TELEGRAM_BOT_TOKEN") if telegram_token: @@ -341,7 +335,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None: config.platforms[Platform.TELEGRAM] = PlatformConfig() config.platforms[Platform.TELEGRAM].enabled = True config.platforms[Platform.TELEGRAM].token = telegram_token - + telegram_home = os.getenv("TELEGRAM_HOME_CHANNEL") if telegram_home and Platform.TELEGRAM in config.platforms: config.platforms[Platform.TELEGRAM].home_channel = HomeChannel( @@ -349,7 +343,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None: chat_id=telegram_home, name=os.getenv("TELEGRAM_HOME_CHANNEL_NAME", "Home"), ) - + # Discord discord_token = os.getenv("DISCORD_BOT_TOKEN") if discord_token: @@ -357,7 +351,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None: config.platforms[Platform.DISCORD] = PlatformConfig() config.platforms[Platform.DISCORD].enabled = True config.platforms[Platform.DISCORD].token = discord_token - + discord_home = os.getenv("DISCORD_HOME_CHANNEL") if discord_home and Platform.DISCORD in config.platforms: config.platforms[Platform.DISCORD].home_channel = HomeChannel( @@ -365,14 +359,14 @@ def _apply_env_overrides(config: GatewayConfig) -> None: chat_id=discord_home, name=os.getenv("DISCORD_HOME_CHANNEL_NAME", "Home"), ) - + # WhatsApp (typically uses different auth mechanism) whatsapp_enabled = os.getenv("WHATSAPP_ENABLED", "").lower() in ("true", "1", "yes") if whatsapp_enabled: if Platform.WHATSAPP not in config.platforms: config.platforms[Platform.WHATSAPP] = PlatformConfig() config.platforms[Platform.WHATSAPP].enabled = True - + # Slack slack_token = os.getenv("SLACK_BOT_TOKEN") if slack_token: @@ -388,7 +382,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None: chat_id=slack_home, name=os.getenv("SLACK_HOME_CHANNEL_NAME", ""), ) - + # Signal signal_url = os.getenv("SIGNAL_HTTP_URL") signal_account = os.getenv("SIGNAL_ACCOUNT") @@ -396,11 +390,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None: if Platform.SIGNAL not in config.platforms: config.platforms[Platform.SIGNAL] = PlatformConfig() config.platforms[Platform.SIGNAL].enabled = True - config.platforms[Platform.SIGNAL].extra.update({ - "http_url": signal_url, - "account": signal_account, - "ignore_stories": os.getenv("SIGNAL_IGNORE_STORIES", "true").lower() in ("true", "1", "yes"), - }) + config.platforms[Platform.SIGNAL].extra.update( + { + "http_url": signal_url, + "account": signal_account, + "ignore_stories": os.getenv("SIGNAL_IGNORE_STORIES", "true").lower() in ("true", "1", "yes"), + } + ) signal_home = os.getenv("SIGNAL_HOME_CHANNEL") if signal_home: config.platforms[Platform.SIGNAL].home_channel = HomeChannel( @@ -427,7 +423,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None: config.default_reset_policy.idle_minutes = int(idle_minutes) except ValueError: pass - + reset_hour = os.getenv("SESSION_RESET_HOUR") if reset_hour: try: @@ -440,6 +436,6 @@ def save_gateway_config(config: GatewayConfig) -> None: """Save gateway configuration to ~/.hermes/gateway.json.""" gateway_config_path = Path.home() / ".hermes" / "gateway.json" gateway_config_path.parent.mkdir(parents=True, exist_ok=True) - + with open(gateway_config_path, "w") as f: json.dump(config.to_dict(), f, indent=2) diff --git a/gateway/delivery.py b/gateway/delivery.py index 0093c1fb09..bb2b52cfa9 100644 --- a/gateway/delivery.py +++ b/gateway/delivery.py @@ -9,18 +9,17 @@ Routes messages to the appropriate destination based on: """ import logging -from pathlib import Path -from datetime import datetime from dataclasses import dataclass -from typing import Dict, List, Optional, Any, Union -from enum import Enum +from datetime import datetime +from pathlib import Path +from typing import Any logger = logging.getLogger(__name__) MAX_PLATFORM_OUTPUT = 4000 TRUNCATED_VISIBLE = 3800 -from .config import Platform, GatewayConfig +from .config import GatewayConfig, Platform from .session import SessionSource @@ -28,23 +27,24 @@ from .session import SessionSource class DeliveryTarget: """ A single delivery target. - + Represents where a message should be sent: - "origin" → back to source - "local" → save to local files - "telegram" → Telegram home channel - "telegram:123456" → specific Telegram chat """ + platform: Platform - chat_id: Optional[str] = None # None means use home channel + chat_id: str | None = None # None means use home channel is_origin: bool = False is_explicit: bool = False # True if chat_id was explicitly specified - + @classmethod - def parse(cls, target: str, origin: Optional[SessionSource] = None) -> "DeliveryTarget": + def parse(cls, target: str, origin: SessionSource | None = None) -> "DeliveryTarget": """ Parse a delivery target string. - + Formats: - "origin" → back to source - "local" → local files only @@ -52,7 +52,7 @@ class DeliveryTarget: - "telegram:123456" → specific Telegram chat """ target = target.strip().lower() - + if target == "origin": if origin: return cls( @@ -63,10 +63,10 @@ class DeliveryTarget: else: # Fallback to local if no origin return cls(platform=Platform.LOCAL, is_origin=True) - + if target == "local": return cls(platform=Platform.LOCAL) - + # Check for platform:chat_id format if ":" in target: platform_str, chat_id = target.split(":", 1) @@ -76,7 +76,7 @@ class DeliveryTarget: except ValueError: # Unknown platform, treat as local return cls(platform=Platform.LOCAL) - + # Just a platform name (use home channel) try: platform = Platform(target) @@ -84,7 +84,7 @@ class DeliveryTarget: except ValueError: # Unknown platform, treat as local return cls(platform=Platform.LOCAL) - + def to_string(self) -> str: """Convert back to string format.""" if self.is_origin: @@ -99,15 +99,15 @@ class DeliveryTarget: class DeliveryRouter: """ Routes messages to appropriate destinations. - + Handles the logic of resolving delivery targets and dispatching messages to the right platform adapters. """ - - def __init__(self, config: GatewayConfig, adapters: Dict[Platform, Any] = None): + + def __init__(self, config: GatewayConfig, adapters: dict[Platform, Any] = None): """ Initialize the delivery router. - + Args: config: Gateway configuration adapters: Dict mapping platforms to their adapter instances @@ -115,31 +115,27 @@ class DeliveryRouter: self.config = config self.adapters = adapters or {} self.output_dir = Path.home() / ".hermes" / "cron" / "output" - - def resolve_targets( - self, - deliver: Union[str, List[str]], - origin: Optional[SessionSource] = None - ) -> List[DeliveryTarget]: + + def resolve_targets(self, deliver: str | list[str], origin: SessionSource | None = None) -> list[DeliveryTarget]: """ Resolve delivery specification to concrete targets. - + Args: deliver: Delivery spec - "origin", "telegram", ["local", "discord"], etc. origin: The source where the request originated (for "origin" target) - + Returns: List of resolved delivery targets """ if isinstance(deliver, str): deliver = [deliver] - + targets = [] seen_platforms = set() - + for target_str in deliver: target = DeliveryTarget.parse(target_str, origin) - + # Resolve home channel if needed if target.chat_id is None and target.platform != Platform.LOCAL: home = self.config.get_home_channel(target.platform) @@ -148,109 +144,96 @@ class DeliveryRouter: else: # No home channel configured, skip this platform continue - + # Deduplicate key = (target.platform, target.chat_id) if key not in seen_platforms: seen_platforms.add(key) targets.append(target) - + # Always include local if configured if self.config.always_log_local: local_key = (Platform.LOCAL, None) if local_key not in seen_platforms: targets.append(DeliveryTarget(platform=Platform.LOCAL)) - + return targets - + async def deliver( self, content: str, - targets: List[DeliveryTarget], - job_id: Optional[str] = None, - job_name: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None - ) -> Dict[str, Any]: + targets: list[DeliveryTarget], + job_id: str | None = None, + job_name: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> dict[str, Any]: """ Deliver content to all specified targets. - + Args: content: The message/output to deliver targets: List of delivery targets job_id: Optional job ID (for cron jobs) job_name: Optional job name metadata: Additional metadata to include - + Returns: Dict with delivery results per target """ results = {} - + for target in targets: try: if target.platform == Platform.LOCAL: result = self._deliver_local(content, job_id, job_name, metadata) else: result = await self._deliver_to_platform(target, content, metadata) - - results[target.to_string()] = { - "success": True, - "result": result - } + + results[target.to_string()] = {"success": True, "result": result} except Exception as e: - results[target.to_string()] = { - "success": False, - "error": str(e) - } - + results[target.to_string()] = {"success": False, "error": str(e)} + return results - + def _deliver_local( - self, - content: str, - job_id: Optional[str], - job_name: Optional[str], - metadata: Optional[Dict[str, Any]] - ) -> Dict[str, Any]: + self, content: str, job_id: str | None, job_name: str | None, metadata: dict[str, Any] | None + ) -> dict[str, Any]: """Save content to local files.""" timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - + if job_id: output_path = self.output_dir / job_id / f"{timestamp}.md" else: output_path = self.output_dir / "misc" / f"{timestamp}.md" - + output_path.parent.mkdir(parents=True, exist_ok=True) - + # Build the output document lines = [] if job_name: lines.append(f"# {job_name}") else: lines.append("# Delivery Output") - + lines.append("") lines.append(f"**Timestamp:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") - + if job_id: lines.append(f"**Job ID:** {job_id}") - + if metadata: for key, value in metadata.items(): lines.append(f"**{key}:** {value}") - + lines.append("") lines.append("---") lines.append("") lines.append(content) - + output_path.write_text("\n".join(lines)) - - return { - "path": str(output_path), - "timestamp": timestamp - } - + + return {"path": str(output_path), "timestamp": timestamp} + def _save_full_output(self, content: str, job_id: str) -> Path: """Save full cron output to disk and return the file path.""" timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") @@ -261,41 +244,33 @@ class DeliveryRouter: return path async def _deliver_to_platform( - self, - target: DeliveryTarget, - content: str, - metadata: Optional[Dict[str, Any]] - ) -> Dict[str, Any]: + self, target: DeliveryTarget, content: str, metadata: dict[str, Any] | None + ) -> dict[str, Any]: """Deliver content to a messaging platform.""" adapter = self.adapters.get(target.platform) - + if not adapter: raise ValueError(f"No adapter configured for {target.platform.value}") - + if not target.chat_id: raise ValueError(f"No chat ID for {target.platform.value} delivery") - + # Guard: truncate oversized cron output to stay within platform limits if len(content) > MAX_PLATFORM_OUTPUT: job_id = (metadata or {}).get("job_id", "unknown") saved_path = self._save_full_output(content, job_id) logger.info("Cron output truncated (%d chars) — full output: %s", len(content), saved_path) - content = ( - content[:TRUNCATED_VISIBLE] - + f"\n\n... [truncated, full output saved to {saved_path}]" - ) - + content = content[:TRUNCATED_VISIBLE] + f"\n\n... [truncated, full output saved to {saved_path}]" + return await adapter.send(target.chat_id, content, metadata=metadata) def parse_deliver_spec( - deliver: Optional[Union[str, List[str]]], - origin: Optional[SessionSource] = None, - default: str = "origin" -) -> Union[str, List[str]]: + deliver: str | list[str] | None, origin: SessionSource | None = None, default: str = "origin" +) -> str | list[str]: """ Normalize a delivery specification. - + If None or empty, returns the default. """ if not deliver: @@ -303,17 +278,14 @@ def parse_deliver_spec( return deliver -def build_delivery_context_for_tool( - config: GatewayConfig, - origin: Optional[SessionSource] = None -) -> Dict[str, Any]: +def build_delivery_context_for_tool(config: GatewayConfig, origin: SessionSource | None = None) -> dict[str, Any]: """ Build context for the schedule_cronjob tool to understand delivery options. - + This is passed to the tool so it can validate and explain delivery targets. """ connected = config.get_connected_platforms() - + options = { "origin": { "description": "Back to where this job was created", @@ -322,9 +294,9 @@ def build_delivery_context_for_tool( "local": { "description": "Save to local files only", "available": True, - } + }, } - + for platform in connected: home = config.get_home_channel(platform) options[platform.value] = { @@ -332,7 +304,7 @@ def build_delivery_context_for_tool( "available": True, "home_channel": home.to_dict() if home else None, } - + return { "origin": origin.to_dict() if origin else None, "options": options, diff --git a/gateway/hooks.py b/gateway/hooks.py index d2face15c5..c32e24ce49 100644 --- a/gateway/hooks.py +++ b/gateway/hooks.py @@ -21,12 +21,12 @@ Errors in hooks are caught and logged but never block the main pipeline. import asyncio import importlib.util import os +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable, Dict, List, Optional +from typing import Any import yaml - HOOKS_DIR = Path(os.path.expanduser("~/.hermes/hooks")) @@ -42,11 +42,11 @@ class HookRegistry: def __init__(self): # event_type -> [handler_fn, ...] - self._handlers: Dict[str, List[Callable]] = {} - self._loaded_hooks: List[dict] = [] # metadata for listing + self._handlers: dict[str, list[Callable]] = {} + self._loaded_hooks: list[dict] = [] # metadata for listing @property - def loaded_hooks(self) -> List[dict]: + def loaded_hooks(self) -> list[dict]: """Return metadata about all loaded hooks.""" return list(self._loaded_hooks) @@ -84,9 +84,7 @@ class HookRegistry: continue # Dynamically load the handler module - spec = importlib.util.spec_from_file_location( - f"hermes_hook_{hook_name}", handler_path - ) + spec = importlib.util.spec_from_file_location(f"hermes_hook_{hook_name}", handler_path) if spec is None or spec.loader is None: print(f"[hooks] Skipping {hook_name}: could not load handler.py", flush=True) continue @@ -103,19 +101,21 @@ class HookRegistry: for event in events: self._handlers.setdefault(event, []).append(handle_fn) - self._loaded_hooks.append({ - "name": hook_name, - "description": manifest.get("description", ""), - "events": events, - "path": str(hook_dir), - }) + self._loaded_hooks.append( + { + "name": hook_name, + "description": manifest.get("description", ""), + "events": events, + "path": str(hook_dir), + } + ) print(f"[hooks] Loaded hook '{hook_name}' for events: {events}", flush=True) except Exception as e: print(f"[hooks] Error loading hook {hook_dir.name}: {e}", flush=True) - async def emit(self, event_type: str, context: Optional[Dict[str, Any]] = None) -> None: + async def emit(self, event_type: str, context: dict[str, Any] | None = None) -> None: """ Fire all handlers registered for an event. diff --git a/gateway/mirror.py b/gateway/mirror.py index 527fc2c13c..133bc84e73 100644 --- a/gateway/mirror.py +++ b/gateway/mirror.py @@ -13,7 +13,6 @@ import json import logging from datetime import datetime from pathlib import Path -from typing import Optional logger = logging.getLogger(__name__) @@ -61,7 +60,7 @@ def mirror_to_session( return False -def _find_session_id(platform: str, chat_id: str) -> Optional[str]: +def _find_session_id(platform: str, chat_id: str) -> str | None: """ Find the active session_id for a platform + chat_id pair. @@ -113,6 +112,7 @@ def _append_to_sqlite(session_id: str, message: dict) -> None: """Append a message to the SQLite session database.""" try: from hermes_state import SessionDB + db = SessionDB() db.append_message( session_id=session_id, diff --git a/gateway/pairing.py b/gateway/pairing.py index b1e066ffe1..7f46c75a50 100644 --- a/gateway/pairing.py +++ b/gateway/pairing.py @@ -23,21 +23,19 @@ import os import secrets import time from pathlib import Path -from typing import Optional - # Unambiguous alphabet -- excludes 0/O, 1/I to prevent confusion ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" CODE_LENGTH = 8 # Timing constants -CODE_TTL_SECONDS = 3600 # Codes expire after 1 hour -RATE_LIMIT_SECONDS = 600 # 1 request per user per 10 minutes -LOCKOUT_SECONDS = 3600 # Lockout duration after too many failures +CODE_TTL_SECONDS = 3600 # Codes expire after 1 hour +RATE_LIMIT_SECONDS = 600 # 1 request per user per 10 minutes +LOCKOUT_SECONDS = 3600 # Lockout duration after too many failures # Limits -MAX_PENDING_PER_PLATFORM = 3 # Max pending codes per platform -MAX_FAILED_ATTEMPTS = 5 # Failed approvals before lockout +MAX_PENDING_PER_PLATFORM = 3 # Max pending codes per platform +MAX_FAILED_ATTEMPTS = 5 # Failed approvals before lockout PAIRING_DIR = Path(os.path.expanduser("~/.hermes/pairing")) @@ -123,9 +121,7 @@ class PairingStore: # ----- Pending codes ----- - def generate_code( - self, platform: str, user_id: str, user_name: str = "" - ) -> Optional[str]: + def generate_code(self, platform: str, user_id: str, user_name: str = "") -> str | None: """ Generate a pairing code for a new user. @@ -165,7 +161,7 @@ class PairingStore: return code - def approve_code(self, platform: str, code: str) -> Optional[dict]: + def approve_code(self, platform: str, code: str) -> dict | None: """ Approve a pairing code. Adds the user to the approved list. @@ -199,13 +195,15 @@ class PairingStore: pending = self._load_json(self._pending_path(p)) for code, info in pending.items(): age_min = int((time.time() - info["created_at"]) / 60) - results.append({ - "platform": p, - "code": code, - "user_id": info["user_id"], - "user_name": info.get("user_name", ""), - "age_minutes": age_min, - }) + results.append( + { + "platform": p, + "code": code, + "user_id": info["user_id"], + "user_name": info.get("user_name", ""), + "age_minutes": age_min, + } + ) return results def clear_pending(self, platform: str = None) -> int: @@ -251,8 +249,11 @@ class PairingStore: lockout_key = f"_lockout:{platform}" limits[lockout_key] = time.time() + LOCKOUT_SECONDS limits[fail_key] = 0 # Reset counter - print(f"[pairing] Platform {platform} locked out for {LOCKOUT_SECONDS}s " - f"after {MAX_FAILED_ATTEMPTS} failed attempts", flush=True) + print( + f"[pairing] Platform {platform} locked out for {LOCKOUT_SECONDS}s " + f"after {MAX_FAILED_ATTEMPTS} failed attempts", + flush=True, + ) self._save_json(self._rate_limit_path(), limits) # ----- Cleanup ----- @@ -262,10 +263,7 @@ class PairingStore: path = self._pending_path(platform) pending = self._load_json(path) now = time.time() - expired = [ - code for code, info in pending.items() - if (now - info["created_at"]) > CODE_TTL_SECONDS - ] + expired = [code for code, info in pending.items() if (now - info["created_at"]) > CODE_TTL_SECONDS] if expired: for code in expired: del pending[code] diff --git a/gateway/platforms/ADDING_A_PLATFORM.md b/gateway/platforms/ADDING_A_PLATFORM.md index dadd9890d9..da549d6218 100644 --- a/gateway/platforms/ADDING_A_PLATFORM.md +++ b/gateway/platforms/ADDING_A_PLATFORM.md @@ -303,8 +303,8 @@ Optional but valuable: After implementing everything, verify with: ```bash -# All tests pass -python -m pytest tests/ -q +# All checks pass (lint + test) +make check # Grep for your platform name to find any missed integration points grep -r "telegram\|discord\|whatsapp\|slack" gateway/ tools/ agent/ cron/ hermes_cli/ toolsets.py \ diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 4dd9cd25d9..3a9fe873a4 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -13,20 +13,20 @@ import uuid from abc import ABC, abstractmethod logger = logging.getLogger(__name__) +import sys +from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from datetime import datetime -from pathlib import Path -from typing import Dict, List, Optional, Any, Callable, Awaitable, Tuple from enum import Enum - -import sys +from pathlib import Path from pathlib import Path as _Path +from typing import Any + sys.path.insert(0, str(_Path(__file__).resolve().parents[2])) from gateway.config import Platform, PlatformConfig from gateway.session import SessionSource - # --------------------------------------------------------------------------- # Image cache utilities # @@ -251,6 +251,7 @@ def cleanup_document_cache(max_age_hours: int = 24) -> int: class MessageType(Enum): """Types of incoming messages.""" + TEXT = "text" LOCATION = "location" PHOTO = "photo" @@ -266,42 +267,43 @@ class MessageType(Enum): class MessageEvent: """ Incoming message from a platform. - + Normalized representation that all adapters produce. """ + # Message content text: str message_type: MessageType = MessageType.TEXT - + # Source information source: SessionSource = None - + # Original platform data raw_message: Any = None - message_id: Optional[str] = None - + message_id: str | None = None + # Media attachments - media_urls: List[str] = field(default_factory=list) - media_types: List[str] = field(default_factory=list) - + media_urls: list[str] = field(default_factory=list) + media_types: list[str] = field(default_factory=list) + # Reply context - reply_to_message_id: Optional[str] = None - + reply_to_message_id: str | None = None + # Timestamps timestamp: datetime = field(default_factory=datetime.now) - + def is_command(self) -> bool: """Check if this is a command message (e.g., /new, /reset).""" return self.text.startswith("/") - - def get_command(self) -> Optional[str]: + + def get_command(self) -> str | None: """Extract command name if this is a command message.""" if not self.is_command(): return None # Split on space and get first word, strip the / parts = self.text.split(maxsplit=1) return parts[0][1:].lower() if parts else None - + def get_command_args(self) -> str: """Get the arguments after a command.""" if not self.is_command(): @@ -310,91 +312,88 @@ class MessageEvent: return parts[1] if len(parts) > 1 else "" -@dataclass +@dataclass class SendResult: """Result of sending a message.""" + success: bool - message_id: Optional[str] = None - error: Optional[str] = None + message_id: str | None = None + error: str | None = None raw_response: Any = None # Type for message handlers -MessageHandler = Callable[[MessageEvent], Awaitable[Optional[str]]] +MessageHandler = Callable[[MessageEvent], Awaitable[str | None]] class BasePlatformAdapter(ABC): """ Base class for platform adapters. - + Subclasses implement platform-specific logic for: - Connecting and authenticating - Receiving messages - Sending messages/responses - Handling media """ - + def __init__(self, config: PlatformConfig, platform: Platform): self.config = config self.platform = platform - self._message_handler: Optional[MessageHandler] = None + self._message_handler: MessageHandler | None = None self._running = False - + # Track active message handlers per session for interrupt support # Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt) - self._active_sessions: Dict[str, asyncio.Event] = {} - self._pending_messages: Dict[str, MessageEvent] = {} - + self._active_sessions: dict[str, asyncio.Event] = {} + self._pending_messages: dict[str, MessageEvent] = {} + @property def name(self) -> str: """Human-readable name for this adapter.""" return self.platform.value.title() - + @property def is_connected(self) -> bool: """Check if adapter is currently connected.""" return self._running - + def set_message_handler(self, handler: MessageHandler) -> None: """ Set the handler for incoming messages. - + The handler receives a MessageEvent and should return an optional response string. """ self._message_handler = handler - + @abstractmethod async def connect(self) -> bool: """ Connect to the platform and start receiving messages. - + Returns True if connection was successful. """ pass - + @abstractmethod async def disconnect(self) -> None: """Disconnect from the platform.""" pass - + @abstractmethod async def send( - self, - chat_id: str, - content: str, - reply_to: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None + self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None ) -> SendResult: """ Send a message to a chat. - + Args: chat_id: The chat/channel ID to send to content: Message content (may be markdown) reply_to: Optional message ID to reply to metadata: Additional platform-specific options - + Returns: SendResult with success status and message ID """ @@ -416,21 +415,21 @@ class BasePlatformAdapter(ABC): async def send_typing(self, chat_id: str) -> None: """ Send a typing indicator. - + Override in subclasses if the platform supports it. """ pass - + async def send_image( self, chat_id: str, image_url: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """ Send an image natively via the platform API. - + Override in subclasses to send images as proper attachments instead of plain-text URLs. Default falls back to sending the URL as a text message. @@ -438,87 +437,91 @@ class BasePlatformAdapter(ABC): # Fallback: send URL as text (subclasses override for native images) text = f"{caption}\n{image_url}" if caption else image_url return await self.send(chat_id=chat_id, content=text, reply_to=reply_to) - + async def send_animation( self, chat_id: str, animation_url: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """ Send an animated GIF natively via the platform API. - + Override in subclasses to send GIFs as proper animations (e.g., Telegram send_animation) so they auto-play inline. Default falls back to send_image. """ return await self.send_image(chat_id=chat_id, image_url=animation_url, caption=caption, reply_to=reply_to) - + @staticmethod def _is_animation_url(url: str) -> bool: """Check if a URL points to an animated GIF (vs a static image).""" - lower = url.lower().split('?')[0] # Strip query params - return lower.endswith('.gif') + lower = url.lower().split("?")[0] # Strip query params + return lower.endswith(".gif") @staticmethod - def extract_images(content: str) -> Tuple[List[Tuple[str, str]], str]: + def extract_images(content: str) -> tuple[list[tuple[str, str]], str]: """ Extract image URLs from markdown and HTML image tags in a response. - + Finds patterns like: - ![alt text](https://example.com/image.png) - - - + Args: content: The response text to scan. - + Returns: Tuple of (list of (url, alt_text) pairs, cleaned content with image tags removed). """ images = [] cleaned = content - + # Match markdown images: ![alt](url) - md_pattern = r'!\[([^\]]*)\]\((https?://[^\s\)]+)\)' + md_pattern = r"!\[([^\]]*)\]\((https?://[^\s\)]+)\)" for match in re.finditer(md_pattern, content): alt_text = match.group(1) url = match.group(2) # Only extract URLs that look like actual images - if any(url.lower().endswith(ext) or ext in url.lower() for ext in - ['.png', '.jpg', '.jpeg', '.gif', '.webp', 'fal.media', 'fal-cdn', 'replicate.delivery']): + if any( + url.lower().endswith(ext) or ext in url.lower() + for ext in [".png", ".jpg", ".jpeg", ".gif", ".webp", "fal.media", "fal-cdn", "replicate.delivery"] + ): images.append((url, alt_text)) - + # Match HTML img tags: or or html_pattern = r']+)["\']?\s*/?>\s*(?:)?' for match in re.finditer(html_pattern, content): url = match.group(1) images.append((url, "")) - + # Remove only the matched image tags from content (not all markdown images) if images: extracted_urls = {url for url, _ in images} + def _remove_if_extracted(match): url = match.group(2) if match.lastindex >= 2 else match.group(1) - return '' if url in extracted_urls else match.group(0) + return "" if url in extracted_urls else match.group(0) + cleaned = re.sub(md_pattern, _remove_if_extracted, cleaned) cleaned = re.sub(html_pattern, _remove_if_extracted, cleaned) # Clean up leftover blank lines - cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip() - + cleaned = re.sub(r"\n{3,}", "\n\n", cleaned).strip() + return images, cleaned - + async def send_voice( self, chat_id: str, audio_path: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """ Send an audio file as a native voice message via the platform API. - + Override in subclasses to send audio as voice bubbles (Telegram) or file attachments (Discord). Default falls back to sending the file path as text. @@ -532,8 +535,8 @@ class BasePlatformAdapter(ABC): self, chat_id: str, video_path: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """ Send a video natively via the platform API. @@ -550,9 +553,9 @@ class BasePlatformAdapter(ABC): self, chat_id: str, file_path: str, - caption: Optional[str] = None, - file_name: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + file_name: str | None = None, + reply_to: str | None = None, ) -> SendResult: """ Send a document/file natively via the platform API. @@ -569,8 +572,8 @@ class BasePlatformAdapter(ABC): self, chat_id: str, image_path: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """ Send a local image file natively via the platform API. @@ -585,45 +588,45 @@ class BasePlatformAdapter(ABC): return await self.send(chat_id=chat_id, content=text, reply_to=reply_to) @staticmethod - def extract_media(content: str) -> Tuple[List[Tuple[str, bool]], str]: + def extract_media(content: str) -> tuple[list[tuple[str, bool]], str]: """ Extract MEDIA: tags and [[audio_as_voice]] directives from response text. - + The TTS tool returns responses like: [[audio_as_voice]] MEDIA:/path/to/audio.ogg - + Args: content: The response text to scan. - + Returns: Tuple of (list of (path, is_voice) pairs, cleaned content with tags removed). """ media = [] cleaned = content - + # Check for [[audio_as_voice]] directive has_voice_tag = "[[audio_as_voice]]" in content cleaned = cleaned.replace("[[audio_as_voice]]", "") - + # Extract MEDIA: tags (path may contain spaces) - media_pattern = r'MEDIA:(\S+)' + media_pattern = r"MEDIA:(\S+)" for match in re.finditer(media_pattern, content): path = match.group(1).strip() if path: media.append((path, has_voice_tag)) - + # Remove MEDIA tags from content if media: - cleaned = re.sub(media_pattern, '', cleaned) - cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip() - + cleaned = re.sub(media_pattern, "", cleaned) + cleaned = re.sub(r"\n{3,}", "\n\n", cleaned).strip() + return media, cleaned - + async def _keep_typing(self, chat_id: str, interval: float = 2.0) -> None: """ Continuously send typing indicator until cancelled. - + Telegram/Discord typing status expires after ~5 seconds, so we refresh every 2 to recover quickly after progress messages interrupt it. """ @@ -633,20 +636,20 @@ class BasePlatformAdapter(ABC): await asyncio.sleep(interval) except asyncio.CancelledError: pass # Normal cancellation when handler completes - + async def handle_message(self, event: MessageEvent) -> None: """ Process an incoming message. - + This method returns quickly by spawning background tasks. This allows new messages to be processed even while an agent is running, enabling interruption support. """ if not self._message_handler: return - + session_key = event.source.chat_id - + # Check if there's already an active handler for this session if session_key in self._active_sessions: # Store this as a pending message - it will interrupt the running agent @@ -655,10 +658,10 @@ class BasePlatformAdapter(ABC): # Signal the interrupt (the processing task checks this) self._active_sessions[session_key].set() return # Don't process now - will be handled after current task finishes - + # Spawn background task to process this message asyncio.create_task(self._process_message_background(event, session_key)) - + @staticmethod def _get_human_delay() -> float: """ @@ -685,35 +688,40 @@ class BasePlatformAdapter(ABC): # Create interrupt event for this session interrupt_event = asyncio.Event() self._active_sessions[session_key] = interrupt_event - + # Start continuous typing indicator (refreshes every 2 seconds) typing_task = asyncio.create_task(self._keep_typing(event.source.chat_id)) - + try: # Call the handler (this can take a while with tool calls) response = await self._message_handler(event) - + # Send response if any if not response: logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id) if response: # Extract MEDIA: tags (from TTS tool) before other processing media_files, response = self.extract_media(response) - + # Extract image URLs and send them as native platform attachments images, text_content = self.extract_images(response) if images: - logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response)) - + logger.info( + "[%s] extract_images found %d image(s) in response (%d chars)", + self.name, + len(images), + len(response), + ) + # Send the text portion first (if any remains after extractions) if text_content: - logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id) - result = await self.send( - chat_id=event.source.chat_id, - content=text_content, - reply_to=event.message_id + logger.info( + "[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id ) - + result = await self.send( + chat_id=event.source.chat_id, content=text_content, reply_to=event.message_id + ) + # Log send failures (don't raise - user already saw tool progress) if not result.success: print(f"[{self.name}] Failed to send response: {result.error}") @@ -721,14 +729,14 @@ class BasePlatformAdapter(ABC): fallback_result = await self.send( chat_id=event.source.chat_id, content=f"(Response formatting failed, plain text:)\n\n{text_content[:3500]}", - reply_to=event.message_id + reply_to=event.message_id, ) if not fallback_result.success: print(f"[{self.name}] Fallback send also failed: {fallback_result.error}") - + # Human-like pacing delay between text and media human_delay = self._get_human_delay() - + # Send extracted images as native attachments if images: logger.info("[%s] Extracted %d image(s) to send as attachments", self.name, len(images)) @@ -736,7 +744,12 @@ class BasePlatformAdapter(ABC): if human_delay > 0: await asyncio.sleep(human_delay) try: - logger.info("[%s] Sending image: %s (alt=%s)", self.name, image_url[:80], alt_text[:30] if alt_text else "") + logger.info( + "[%s] Sending image: %s (alt=%s)", + self.name, + image_url[:80], + alt_text[:30] if alt_text else "", + ) # Route animated GIFs through send_animation for proper playback if self._is_animation_url(image_url): img_result = await self.send_animation( @@ -754,11 +767,11 @@ class BasePlatformAdapter(ABC): logger.error("[%s] Failed to send image: %s", self.name, img_result.error) except Exception as img_err: logger.error("[%s] Error sending image: %s", self.name, img_err, exc_info=True) - + # Send extracted media files — route by file type - _AUDIO_EXTS = {'.ogg', '.opus', '.mp3', '.wav', '.m4a'} - _VIDEO_EXTS = {'.mp4', '.mov', '.avi', '.mkv', '.3gp'} - _IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.webp', '.gif'} + _AUDIO_EXTS = {".ogg", ".opus", ".mp3", ".wav", ".m4a"} + _VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".3gp"} + _IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".gif"} for media_path, is_voice in media_files: if human_delay > 0: @@ -790,7 +803,7 @@ class BasePlatformAdapter(ABC): print(f"[{self.name}] Failed to send media ({ext}): {media_result.error}") except Exception as media_err: print(f"[{self.name}] Error sending media: {media_err}") - + # Check if there's a pending message that was queued during our processing if session_key in self._pending_messages: pending_event = self._pending_messages.pop(session_key) @@ -806,10 +819,11 @@ class BasePlatformAdapter(ABC): # Process pending message in new background task await self._process_message_background(pending_event, session_key) return # Already cleaned up - + except Exception as e: print(f"[{self.name}] Error handling message: {e}") import traceback + traceback.print_exc() finally: # Stop typing indicator @@ -821,26 +835,26 @@ class BasePlatformAdapter(ABC): # Clean up session tracking if session_key in self._active_sessions: del self._active_sessions[session_key] - + def has_pending_interrupt(self, session_key: str) -> bool: """Check if there's a pending interrupt for a session.""" return session_key in self._active_sessions and self._active_sessions[session_key].is_set() - - def get_pending_message(self, session_key: str) -> Optional[MessageEvent]: + + def get_pending_message(self, session_key: str) -> MessageEvent | None: """Get and clear any pending message for a session.""" return self._pending_messages.pop(session_key, None) - + def build_source( self, chat_id: str, - chat_name: Optional[str] = None, + chat_name: str | None = None, chat_type: str = "dm", - user_id: Optional[str] = None, - user_name: Optional[str] = None, - thread_id: Optional[str] = None, - chat_topic: Optional[str] = None, - user_id_alt: Optional[str] = None, - chat_id_alt: Optional[str] = None, + user_id: str | None = None, + user_name: str | None = None, + thread_id: str | None = None, + chat_topic: str | None = None, + user_id_alt: str | None = None, + chat_id_alt: str | None = None, ) -> SessionSource: """Helper to build a SessionSource for this platform.""" # Normalize empty topic to None @@ -858,30 +872,30 @@ class BasePlatformAdapter(ABC): user_id_alt=user_id_alt, chat_id_alt=chat_id_alt, ) - + @abstractmethod - async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + async def get_chat_info(self, chat_id: str) -> dict[str, Any]: """ Get information about a chat/channel. - + Returns dict with at least: - name: Chat name - type: "dm", "group", "channel" """ pass - + def format_message(self, content: str) -> str: """ Format a message for this platform. - + Override in subclasses to handle platform-specific formatting (e.g., Telegram MarkdownV2, Discord markdown). - + Default implementation returns content as-is. """ return content - - def truncate_message(self, content: str, max_length: int = 4096) -> List[str]: + + def truncate_message(self, content: str, max_length: int = 4096) -> list[str]: """ Split a long message into chunks, preserving code block boundaries. @@ -900,14 +914,14 @@ class BasePlatformAdapter(ABC): if len(content) <= max_length: return [content] - INDICATOR_RESERVE = 10 # room for " (XX/XX)" + INDICATOR_RESERVE = 10 # room for " (XX/XX)" FENCE_CLOSE = "\n```" - chunks: List[str] = [] + chunks: list[str] = [] remaining = content # When the previous chunk ended mid-code-block, this holds the # language tag (possibly "") so we can reopen the fence. - carry_lang: Optional[str] = None + carry_lang: str | None = None while remaining: # If we're continuing a code block from the previous chunk, @@ -965,8 +979,6 @@ class BasePlatformAdapter(ABC): # Append chunk indicators when the response spans multiple messages if len(chunks) > 1: total = len(chunks) - chunks = [ - f"{chunk} ({i + 1}/{total})" for i, chunk in enumerate(chunks) - ] + chunks = [f"{chunk} ({i + 1}/{total})" for i, chunk in enumerate(chunks)] return chunks diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index 905e20d6f4..54d9d27b3e 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -10,14 +10,16 @@ Uses discord.py library for: import asyncio import logging import os -from typing import Dict, List, Optional, Any +from typing import Any logger = logging.getLogger(__name__) try: import discord - from discord import Message as DiscordMessage, Intents + from discord import Intents + from discord import Message as DiscordMessage from discord.ext import commands + DISCORD_AVAILABLE = True except ImportError: DISCORD_AVAILABLE = False @@ -28,6 +30,7 @@ except ImportError: import sys from pathlib import Path as _Path + sys.path.insert(0, str(_Path(__file__).resolve().parents[2])) from gateway.config import Platform, PlatformConfig @@ -36,8 +39,8 @@ from gateway.platforms.base import ( MessageEvent, MessageType, SendResult, - cache_image_from_url, cache_audio_from_url, + cache_image_from_url, ) @@ -49,7 +52,7 @@ def check_discord_requirements() -> bool: class DiscordAdapter(BasePlatformAdapter): """ Discord bot adapter. - + Handles: - Receiving messages from servers and DMs - Sending responses with Discord markdown @@ -59,26 +62,26 @@ class DiscordAdapter(BasePlatformAdapter): - Auto-threading for long conversations - Reaction-based feedback """ - + # Discord message limits MAX_MESSAGE_LENGTH = 2000 - + def __init__(self, config: PlatformConfig): super().__init__(config, Platform.DISCORD) - self._client: Optional[commands.Bot] = None + self._client: commands.Bot | None = None self._ready_event = asyncio.Event() self._allowed_user_ids: set = set() # For button approval authorization - + async def connect(self) -> bool: """Connect to Discord and start receiving events.""" if not DISCORD_AVAILABLE: print(f"[{self.name}] discord.py not installed. Run: pip install discord.py") return False - + if not self.config.token: print(f"[{self.name}] No bot token configured") return False - + try: # Set up intents -- members intent needed for username-to-ID resolution intents = Intents.default() @@ -86,30 +89,28 @@ class DiscordAdapter(BasePlatformAdapter): intents.dm_messages = True intents.guild_messages = True intents.members = True - + # Create bot self._client = commands.Bot( command_prefix="!", # Not really used, we handle raw messages intents=intents, ) - + # Parse allowed user entries (may contain usernames or IDs) allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "") if allowed_env: - self._allowed_user_ids = { - uid.strip() for uid in allowed_env.split(",") if uid.strip() - } - + self._allowed_user_ids = {uid.strip() for uid in allowed_env.split(",") if uid.strip()} + adapter_self = self # capture for closure - + # Register event handlers @self._client.event async def on_ready(): print(f"[{adapter_self.name}] Connected as {adapter_self._client.user}") - + # Resolve any usernames in the allowed list to numeric IDs await adapter_self._resolve_allowed_usernames() - + # Sync slash commands with Discord try: synced = await adapter_self._client.tree.sync() @@ -117,33 +118,33 @@ class DiscordAdapter(BasePlatformAdapter): except Exception as e: print(f"[{adapter_self.name}] Slash command sync failed: {e}") adapter_self._ready_event.set() - + @self._client.event async def on_message(message: DiscordMessage): # Ignore bot's own messages if message.author == self._client.user: return await self._handle_message(message) - + # Register slash commands self._register_slash_commands() - + # Start the bot in background asyncio.create_task(self._client.start(self.config.token)) - + # Wait for ready await asyncio.wait_for(self._ready_event.wait(), timeout=30) - + self._running = True return True - - except asyncio.TimeoutError: + + except TimeoutError: print(f"[{self.name}] Timeout waiting for connection") return False except Exception as e: print(f"[{self.name}] Failed to connect: {e}") return False - + async def disconnect(self) -> None: """Disconnect from Discord.""" if self._client: @@ -151,59 +152,55 @@ class DiscordAdapter(BasePlatformAdapter): await self._client.close() except Exception as e: print(f"[{self.name}] Error during disconnect: {e}") - + self._running = False self._client = None self._ready_event.clear() print(f"[{self.name}] Disconnected") - + async def send( - self, - chat_id: str, - content: str, - reply_to: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None + self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None ) -> SendResult: """Send a message to a Discord channel.""" if not self._client: return SendResult(success=False, error="Not connected") - + try: # Get the channel channel = self._client.get_channel(int(chat_id)) if not channel: channel = await self._client.fetch_channel(int(chat_id)) - + if not channel: return SendResult(success=False, error=f"Channel {chat_id} not found") - + # Format and split message if needed formatted = self.format_message(content) chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH) - + message_ids = [] reference = None - + if reply_to: try: ref_msg = await channel.fetch_message(int(reply_to)) reference = ref_msg except Exception as e: logger.debug("Could not fetch reply-to message: %s", e) - + for i, chunk in enumerate(chunks): msg = await channel.send( content=chunk, reference=reference if i == 0 else None, ) message_ids.append(str(msg.id)) - + return SendResult( success=True, message_id=message_ids[0] if message_ids else None, - raw_response={"message_ids": message_ids} + raw_response={"message_ids": message_ids}, ) - + except Exception as e: return SendResult(success=False, error=str(e)) @@ -223,7 +220,7 @@ class DiscordAdapter(BasePlatformAdapter): msg = await channel.fetch_message(int(message_id)) formatted = self.format_message(content) if len(formatted) > self.MAX_MESSAGE_LENGTH: - formatted = formatted[:self.MAX_MESSAGE_LENGTH - 3] + "..." + formatted = formatted[: self.MAX_MESSAGE_LENGTH - 3] + "..." await msg.edit(content=formatted) return SendResult(success=True, message_id=message_id) except Exception as e: @@ -233,28 +230,28 @@ class DiscordAdapter(BasePlatformAdapter): self, chat_id: str, audio_path: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """Send audio as a Discord file attachment.""" if not self._client: return SendResult(success=False, error="Not connected") - + try: import io - + channel = self._client.get_channel(int(chat_id)) if not channel: channel = await self._client.fetch_channel(int(chat_id)) if not channel: return SendResult(success=False, error=f"Channel {chat_id} not found") - + if not os.path.exists(audio_path): return SendResult(success=False, error=f"Audio file not found: {audio_path}") - + # Determine filename from path filename = os.path.basename(audio_path) - + with open(audio_path, "rb") as f: file = discord.File(io.BytesIO(f.read()), filename=filename) msg = await channel.send( @@ -262,36 +259,36 @@ class DiscordAdapter(BasePlatformAdapter): file=file, ) return SendResult(success=True, message_id=str(msg.id)) - + except Exception as e: print(f"[{self.name}] Failed to send audio: {e}") return await super().send_voice(chat_id, audio_path, caption, reply_to) - + async def send_image_file( self, chat_id: str, image_path: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """Send a local image file natively as a Discord file attachment.""" if not self._client: return SendResult(success=False, error="Not connected") - + try: import io - + channel = self._client.get_channel(int(chat_id)) if not channel: channel = await self._client.fetch_channel(int(chat_id)) if not channel: return SendResult(success=False, error=f"Channel {chat_id} not found") - + if not os.path.exists(image_path): return SendResult(success=False, error=f"Image file not found: {image_path}") - + filename = os.path.basename(image_path) - + with open(image_path, "rb") as f: file = discord.File(io.BytesIO(f.read()), filename=filename) msg = await channel.send( @@ -299,7 +296,7 @@ class DiscordAdapter(BasePlatformAdapter): file=file, ) return SendResult(success=True, message_id=str(msg.id)) - + except Exception as e: print(f"[{self.name}] Failed to send local image: {e}") return await super().send_image_file(chat_id, image_path, caption, reply_to) @@ -308,31 +305,31 @@ class DiscordAdapter(BasePlatformAdapter): self, chat_id: str, image_url: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """Send an image natively as a Discord file attachment.""" if not self._client: return SendResult(success=False, error="Not connected") - + try: import aiohttp - + channel = self._client.get_channel(int(chat_id)) if not channel: channel = await self._client.fetch_channel(int(chat_id)) if not channel: return SendResult(success=False, error=f"Channel {chat_id} not found") - + # Download the image and send as a Discord file attachment # (Discord renders attachments inline, unlike plain URLs) async with aiohttp.ClientSession() as session: async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30)) as resp: if resp.status != 200: raise Exception(f"Failed to download image: HTTP {resp.status}") - + image_data = await resp.read() - + # Determine filename from URL or content type content_type = resp.headers.get("content-type", "image/png") ext = "png" @@ -342,23 +339,24 @@ class DiscordAdapter(BasePlatformAdapter): ext = "gif" elif "webp" in content_type: ext = "webp" - + import io + file = discord.File(io.BytesIO(image_data), filename=f"image.{ext}") - + msg = await channel.send( content=caption if caption else None, file=file, ) return SendResult(success=True, message_id=str(msg.id)) - + except ImportError: print(f"[{self.name}] aiohttp not installed, falling back to URL. Run: pip install aiohttp") return await super().send_image(chat_id, image_url, caption, reply_to) except Exception as e: print(f"[{self.name}] Failed to send image attachment, falling back to URL: {e}") return await super().send_image(chat_id, image_url, caption, reply_to) - + async def send_typing(self, chat_id: str) -> None: """Send typing indicator.""" if self._client: @@ -368,20 +366,20 @@ class DiscordAdapter(BasePlatformAdapter): await channel.typing() except Exception: pass # Ignore typing indicator failures - - async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + + async def get_chat_info(self, chat_id: str) -> dict[str, Any]: """Get information about a Discord channel.""" if not self._client: return {"name": "Unknown", "type": "dm"} - + try: channel = self._client.get_channel(int(chat_id)) if not channel: channel = await self._client.fetch_channel(int(chat_id)) - + if not channel: return {"name": str(chat_id), "type": "dm"} - + # Determine channel type if isinstance(channel, discord.DMChannel): chat_type = "dm" @@ -397,7 +395,7 @@ class DiscordAdapter(BasePlatformAdapter): else: chat_type = "channel" name = getattr(channel, "name", str(chat_id)) - + return { "name": name, "type": chat_type, @@ -406,7 +404,7 @@ class DiscordAdapter(BasePlatformAdapter): } except Exception as e: return {"name": str(chat_id), "type": "dm", "error": str(e)} - + async def _resolve_allowed_usernames(self) -> None: """ Resolve non-numeric entries in DISCORD_ALLOWED_USERS to Discord user IDs. @@ -453,8 +451,10 @@ class DiscordAdapter(BasePlatformAdapter): uid = str(member.id) numeric_ids.add(uid) resolved_count += 1 - matched_name = name_lower if name_lower in to_resolve else ( - display_lower if display_lower in to_resolve else global_lower + matched_name = ( + name_lower + if name_lower in to_resolve + else (display_lower if display_lower in to_resolve else global_lower) ) to_resolve.discard(matched_name) print(f"[{self.name}] Resolved '{matched_name}' -> {uid} ({member.name}#{member.discriminator})") @@ -474,12 +474,12 @@ class DiscordAdapter(BasePlatformAdapter): def format_message(self, content: str) -> str: """ Format message for Discord. - + Discord uses its own markdown variant. """ # Discord markdown is fairly standard, no special escaping needed return content - + def _register_slash_commands(self) -> None: """Register Discord slash commands on the command tree.""" if not self._client: @@ -694,7 +694,7 @@ class DiscordAdapter(BasePlatformAdapter): chat_name = interaction.channel.name if hasattr(interaction.channel, "guild") and interaction.channel.guild: chat_name = f"{interaction.channel.guild.name} / #{chat_name}" - + # Get channel topic (if available) chat_topic = getattr(interaction.channel, "topic", None) @@ -715,9 +715,7 @@ class DiscordAdapter(BasePlatformAdapter): raw_message=interaction, ) - async def send_exec_approval( - self, chat_id: str, command: str, approval_id: str - ) -> SendResult: + async def send_exec_approval(self, chat_id: str, command: str, approval_id: str) -> SendResult: """ Send a button-based exec approval prompt for a dangerous command. @@ -759,28 +757,28 @@ class DiscordAdapter(BasePlatformAdapter): # bot responds to every message without needing a mention. # DISCORD_REQUIRE_MENTION: Set to "false" to disable mention requirement # globally (all channels become free-response). Default: "true". - + if not isinstance(message.channel, discord.DMChannel): # Check if this channel is in the free-response list free_channels_raw = os.getenv("DISCORD_FREE_RESPONSE_CHANNELS", "") free_channels = {ch.strip() for ch in free_channels_raw.split(",") if ch.strip()} channel_id = str(message.channel.id) - + # Global override: if DISCORD_REQUIRE_MENTION=false, all channels are free require_mention = os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no") - + is_free_channel = channel_id in free_channels - + if require_mention and not is_free_channel: # Must be @mentioned to respond if self._client.user not in message.mentions: return # Silently ignore messages that don't mention the bot - + # Strip the bot mention from the message text so the agent sees clean input if self._client.user and self._client.user in message.mentions: message.content = message.content.replace(f"<@{self._client.user.id}>", "").strip() message.content = message.content.replace(f"<@!{self._client.user.id}>", "").strip() - + # Determine message type msg_type = MessageType.TEXT if message.content.startswith("/"): @@ -798,7 +796,7 @@ class DiscordAdapter(BasePlatformAdapter): else: msg_type = MessageType.DOCUMENT break - + # Determine chat type if isinstance(message.channel, discord.DMChannel): chat_type = "dm" @@ -811,15 +809,15 @@ class DiscordAdapter(BasePlatformAdapter): chat_name = getattr(message.channel, "name", str(message.channel.id)) if hasattr(message.channel, "guild") and message.channel.guild: chat_name = f"{message.channel.guild.name} / #{chat_name}" - + # Get thread ID if in a thread thread_id = None if isinstance(message.channel, discord.Thread): thread_id = str(message.channel.id) - + # Get channel topic (if available - TextChannels have topics, DMs/threads don't) chat_topic = getattr(message.channel, "topic", None) - + # Build source source = self.build_source( chat_id=str(message.channel.id), @@ -830,7 +828,7 @@ class DiscordAdapter(BasePlatformAdapter): thread_id=thread_id, chat_topic=chat_topic, ) - + # Build media URLs -- download image attachments to local cache so the # vision tool can access them reliably (Discord CDN URLs can expire). media_urls = [] @@ -869,7 +867,7 @@ class DiscordAdapter(BasePlatformAdapter): # Other attachments: keep the original URL media_urls.append(att.url) media_types.append(content_type) - + event = MessageEvent( text=message.content, message_type=msg_type, @@ -881,7 +879,7 @@ class DiscordAdapter(BasePlatformAdapter): reply_to_message_id=str(message.reference.message_id) if message.reference else None, timestamp=message.created_at, ) - + await self.handle_message(event) @@ -911,20 +909,14 @@ if DISCORD_AVAILABLE: return True # No allowlist = anyone can approve return str(interaction.user.id) in self.allowed_user_ids - async def _resolve( - self, interaction: discord.Interaction, action: str, color: discord.Color - ): + async def _resolve(self, interaction: discord.Interaction, action: str, color: discord.Color): """Resolve the approval and update the message.""" if self.resolved: - await interaction.response.send_message( - "This approval has already been resolved~", ephemeral=True - ) + await interaction.response.send_message("This approval has already been resolved~", ephemeral=True) return if not self._check_auth(interaction): - await interaction.response.send_message( - "You're not authorized to approve commands~", ephemeral=True - ) + await interaction.response.send_message("You're not authorized to approve commands~", ephemeral=True) return self.resolved = True @@ -944,6 +936,7 @@ if DISCORD_AVAILABLE: # Store the approval decision try: from tools.approval import approve_permanent + if action == "allow_once": pass # One-time approval handled by gateway elif action == "allow_always": @@ -952,21 +945,15 @@ if DISCORD_AVAILABLE: pass @discord.ui.button(label="Allow Once", style=discord.ButtonStyle.green) - async def allow_once( - self, interaction: discord.Interaction, button: discord.ui.Button - ): + async def allow_once(self, interaction: discord.Interaction, button: discord.ui.Button): await self._resolve(interaction, "allow_once", discord.Color.green()) @discord.ui.button(label="Always Allow", style=discord.ButtonStyle.blurple) - async def allow_always( - self, interaction: discord.Interaction, button: discord.ui.Button - ): + async def allow_always(self, interaction: discord.Interaction, button: discord.ui.Button): await self._resolve(interaction, "allow_always", discord.Color.blue()) @discord.ui.button(label="Deny", style=discord.ButtonStyle.red) - async def deny( - self, interaction: discord.Interaction, button: discord.ui.Button - ): + async def deny(self, interaction: discord.Interaction, button: discord.ui.Button): await self._resolve(interaction, "deny", discord.Color.red()) async def on_timeout(self): diff --git a/gateway/platforms/homeassistant.py b/gateway/platforms/homeassistant.py index a900ef3b77..2fc155a0e7 100644 --- a/gateway/platforms/homeassistant.py +++ b/gateway/platforms/homeassistant.py @@ -19,10 +19,11 @@ import os import time import uuid from datetime import datetime -from typing import Any, Dict, List, Optional, Set +from typing import Any try: import aiohttp + AIOHTTP_AVAILABLE = True except ImportError: AIOHTTP_AVAILABLE = False @@ -66,10 +67,10 @@ class HomeAssistantAdapter(BasePlatformAdapter): super().__init__(config, Platform.HOMEASSISTANT) # Connection state - self._session: Optional["aiohttp.ClientSession"] = None - self._ws: Optional["aiohttp.ClientWebSocketResponse"] = None - self._rest_session: Optional["aiohttp.ClientSession"] = None - self._listen_task: Optional[asyncio.Task] = None + self._session: aiohttp.ClientSession | None = None + self._ws: aiohttp.ClientWebSocketResponse | None = None + self._rest_session: aiohttp.ClientSession | None = None + self._listen_task: asyncio.Task | None = None self._msg_id: int = 0 # Configuration from extra @@ -80,13 +81,13 @@ class HomeAssistantAdapter(BasePlatformAdapter): self._hass_token: str = token # Event filtering - self._watch_domains: Set[str] = set(extra.get("watch_domains", [])) - self._watch_entities: Set[str] = set(extra.get("watch_entities", [])) - self._ignore_entities: Set[str] = set(extra.get("ignore_entities", [])) + self._watch_domains: set[str] = set(extra.get("watch_domains", [])) + self._watch_entities: set[str] = set(extra.get("watch_entities", [])) + self._ignore_entities: set[str] = set(extra.get("ignore_entities", [])) self._cooldown_seconds: int = int(extra.get("cooldown_seconds", 30)) # Cooldown tracking: entity_id -> last_event_timestamp - self._last_event_time: Dict[str, float] = {} + self._last_event_time: dict[str, float] = {} def _next_id(self) -> int: """Return the next WebSocket message ID.""" @@ -141,10 +142,12 @@ class HomeAssistantAdapter(BasePlatformAdapter): return False # Step 2: Send auth - await self._ws.send_json({ - "type": "auth", - "access_token": self._hass_token, - }) + await self._ws.send_json( + { + "type": "auth", + "access_token": self._hass_token, + } + ) # Step 3: Wait for auth_ok msg = await self._ws.receive_json() @@ -155,11 +158,13 @@ class HomeAssistantAdapter(BasePlatformAdapter): # Step 4: Subscribe to state_changed events sub_id = self._next_id() - await self._ws.send_json({ - "id": sub_id, - "type": "subscribe_events", - "event_type": "state_changed", - }) + await self._ws.send_json( + { + "id": sub_id, + "type": "subscribe_events", + "event_type": "state_changed", + } + ) # Verify subscription acknowledgement msg = await self._ws.receive_json() @@ -245,7 +250,7 @@ class HomeAssistantAdapter(BasePlatformAdapter): elif ws_msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR): break - async def _handle_ha_event(self, event: Dict[str, Any]) -> None: + async def _handle_ha_event(self, event: dict[str, Any]) -> None: """Process a state_changed event from Home Assistant.""" event_data = event.get("data", {}) entity_id: str = event_data.get("entity_id", "") @@ -302,9 +307,9 @@ class HomeAssistantAdapter(BasePlatformAdapter): @staticmethod def _format_state_change( entity_id: str, - old_state: Dict[str, Any], - new_state: Dict[str, Any], - ) -> Optional[str]: + old_state: dict[str, Any], + new_state: dict[str, Any], + ) -> str | None: """Convert a state_changed event into a human-readable description.""" if not new_state: return None @@ -331,10 +336,7 @@ class HomeAssistantAdapter(BasePlatformAdapter): if domain == "sensor": unit = new_state.get("attributes", {}).get("unit_of_measurement", "") - return ( - f"[Home Assistant] {friendly_name}: changed from " - f"{old_val}{unit} to {new_val}{unit}" - ) + return f"[Home Assistant] {friendly_name}: changed from {old_val}{unit} to {new_val}{unit}" if domain == "binary_sensor": return ( @@ -344,22 +346,13 @@ class HomeAssistantAdapter(BasePlatformAdapter): ) if domain in ("light", "switch", "fan"): - return ( - f"[Home Assistant] {friendly_name}: turned " - f"{'on' if new_val == 'on' else 'off'}" - ) + return f"[Home Assistant] {friendly_name}: turned {'on' if new_val == 'on' else 'off'}" if domain == "alarm_control_panel": - return ( - f"[Home Assistant] {friendly_name}: alarm state changed from " - f"'{old_val}' to '{new_val}'" - ) + return f"[Home Assistant] {friendly_name}: alarm state changed from '{old_val}' to '{new_val}'" # Generic fallback - return ( - f"[Home Assistant] {friendly_name} ({entity_id}): " - f"changed from '{old_val}' to '{new_val}'" - ) + return f"[Home Assistant] {friendly_name} ({entity_id}): changed from '{old_val}' to '{new_val}'" # ------------------------------------------------------------------ # Outbound messaging @@ -369,8 +362,8 @@ class HomeAssistantAdapter(BasePlatformAdapter): self, chat_id: str, content: str, - reply_to: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, + reply_to: str | None = None, + metadata: dict[str, Any] | None = None, ) -> SendResult: """Send a notification via HA REST API (persistent_notification.create). @@ -384,7 +377,7 @@ class HomeAssistantAdapter(BasePlatformAdapter): } payload = { "title": "Hermes Agent", - "message": content[:self.MAX_MESSAGE_LENGTH], + "message": content[: self.MAX_MESSAGE_LENGTH], } try: @@ -401,20 +394,22 @@ class HomeAssistantAdapter(BasePlatformAdapter): body = await resp.text() return SendResult(success=False, error=f"HTTP {resp.status}: {body}") else: - async with aiohttp.ClientSession() as session: - async with session.post( + async with ( + aiohttp.ClientSession() as session, + session.post( url, headers=headers, json=payload, timeout=aiohttp.ClientTimeout(total=10), - ) as resp: - if resp.status < 300: - return SendResult(success=True, message_id=uuid.uuid4().hex[:12]) - else: - body = await resp.text() - return SendResult(success=False, error=f"HTTP {resp.status}: {body}") + ) as resp, + ): + if resp.status < 300: + return SendResult(success=True, message_id=uuid.uuid4().hex[:12]) + else: + body = await resp.text() + return SendResult(success=False, error=f"HTTP {resp.status}: {body}") - except asyncio.TimeoutError: + except TimeoutError: return SendResult(success=False, error="Timeout sending notification to HA") except Exception as e: return SendResult(success=False, error=str(e)) @@ -423,7 +418,7 @@ class HomeAssistantAdapter(BasePlatformAdapter): """No typing indicator for Home Assistant.""" pass - async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + async def get_chat_info(self, chat_id: str) -> dict[str, Any]: """Return basic info about the HA event channel.""" return { "name": "Home Assistant Events", diff --git a/gateway/platforms/signal.py b/gateway/platforms/signal.py index 62e7e4b63b..ab550b6568 100644 --- a/gateway/platforms/signal.py +++ b/gateway/platforms/signal.py @@ -19,9 +19,9 @@ import os import random import re import time -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path -from typing import Dict, List, Optional, Any +from typing import Any from urllib.parse import unquote import httpx @@ -32,9 +32,9 @@ from gateway.platforms.base import ( MessageEvent, MessageType, SendResult, - cache_image_from_bytes, cache_audio_from_bytes, cache_document_from_bytes, + cache_image_from_bytes, cache_image_from_url, ) @@ -59,6 +59,7 @@ _PHONE_RE = re.compile(r"\+[1-9]\d{6,14}") # Helpers # --------------------------------------------------------------------------- + def _redact_phone(phone: str) -> str: """Redact a phone number for logging: +15551234567 -> +155****4567.""" if not phone: @@ -68,7 +69,7 @@ def _redact_phone(phone: str) -> str: return phone[:4] + "****" + phone[-4:] -def _parse_comma_list(value: str) -> List[str]: +def _parse_comma_list(value: str) -> list[str]: """Split a comma-separated string into a list, stripping whitespace.""" return [v.strip() for v in value.split(",") if v.strip()] @@ -110,7 +111,7 @@ def _render_mentions(text: str, mentions: list) -> str: Signal encodes @mentions as the Unicode object replacement character with out-of-band metadata containing the mentioned user's UUID/number. """ - if not mentions or "\uFFFC" not in text: + if not mentions or "\ufffc" not in text: return text # Sort mentions by start position (reverse) to replace from end to start # so indices don't shift as we replace @@ -121,7 +122,7 @@ def _render_mentions(text: str, mentions: list) -> str: # Use the mention's number or UUID as the replacement identifier = mention.get("number") or mention.get("uuid") or "user" replacement = f"@{identifier}" - text = text[:start] + replacement + text[start + length:] + text = text[:start] + replacement + text[start + length :] return text @@ -134,6 +135,7 @@ def check_signal_requirements() -> bool: # Signal Adapter # --------------------------------------------------------------------------- + class SignalAdapter(BasePlatformAdapter): """Signal messenger adapter using signal-cli HTTP daemon.""" @@ -152,22 +154,25 @@ class SignalAdapter(BasePlatformAdapter): self.group_allow_from = set(_parse_comma_list(group_allowed_str)) # HTTP client - self.client: Optional[httpx.AsyncClient] = None + self.client: httpx.AsyncClient | None = None # Background tasks - self._sse_task: Optional[asyncio.Task] = None - self._health_monitor_task: Optional[asyncio.Task] = None - self._typing_tasks: Dict[str, asyncio.Task] = {} + self._sse_task: asyncio.Task | None = None + self._health_monitor_task: asyncio.Task | None = None + self._typing_tasks: dict[str, asyncio.Task] = {} self._running = False self._last_sse_activity = 0.0 - self._sse_response: Optional[httpx.Response] = None + self._sse_response: httpx.Response | None = None # Normalize account for self-message filtering self._account_normalized = self.account.strip() - logger.info("Signal adapter initialized: url=%s account=%s groups=%s", - self.http_url, _redact_phone(self.account), - "enabled" if self.group_allow_from else "disabled") + logger.info( + "Signal adapter initialized: url=%s account=%s groups=%s", + self.http_url, + _redact_phone(self.account), + "enabled" if self.group_allow_from else "disabled", + ) # ------------------------------------------------------------------ # Lifecycle @@ -241,7 +246,8 @@ class SignalAdapter(BasePlatformAdapter): try: logger.debug("Signal SSE: connecting to %s", url) async with self.client.stream( - "GET", url, + "GET", + url, headers={"Accept": "text/event-stream"}, timeout=None, ) as response: @@ -306,9 +312,7 @@ class SignalAdapter(BasePlatformAdapter): if elapsed > HEALTH_CHECK_STALE_THRESHOLD: logger.warning("Signal: SSE idle for %.0fs, checking daemon health", elapsed) try: - resp = await self.client.get( - f"{self.http_url}/api/v1/check", timeout=10.0 - ) + resp = await self.client.get(f"{self.http_url}/api/v1/check", timeout=10.0) if resp.status_code == 200: # Daemon is alive but SSE is idle — update activity to # avoid repeated warnings (connection may just be quiet) @@ -345,11 +349,7 @@ class SignalAdapter(BasePlatformAdapter): return # Extract sender info - sender = ( - envelope_data.get("sourceNumber") - or envelope_data.get("sourceUuid") - or envelope_data.get("source") - ) + sender = envelope_data.get("sourceNumber") or envelope_data.get("sourceUuid") or envelope_data.get("source") sender_name = envelope_data.get("sourceName", "") sender_uuid = envelope_data.get("sourceUuid", "") @@ -367,10 +367,7 @@ class SignalAdapter(BasePlatformAdapter): # Get data message — also check editMessage (edited messages contain # their updated dataMessage inside editMessage.dataMessage) - data_message = ( - envelope_data.get("dataMessage") - or (envelope_data.get("editMessage") or {}).get("dataMessage") - ) + data_message = envelope_data.get("dataMessage") or (envelope_data.get("editMessage") or {}).get("dataMessage") if not data_message: return @@ -451,11 +448,11 @@ class SignalAdapter(BasePlatformAdapter): ts_ms = envelope_data.get("timestamp", 0) if ts_ms: try: - timestamp = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc) + timestamp = datetime.fromtimestamp(ts_ms / 1000, tz=UTC) except (ValueError, OSError): - timestamp = datetime.now(tz=timezone.utc) + timestamp = datetime.now(tz=UTC) else: - timestamp = datetime.now(tz=timezone.utc) + timestamp = datetime.now(tz=UTC) # Build and dispatch event event = MessageEvent( @@ -468,8 +465,7 @@ class SignalAdapter(BasePlatformAdapter): timestamp=timestamp, ) - logger.debug("Signal: message from %s in %s: %s", - _redact_phone(sender), chat_id[:20], (text or "")[:50]) + logger.debug("Signal: message from %s in %s: %s", _redact_phone(sender), chat_id[:20], (text or "")[:50]) await self.handle_message(event) @@ -479,10 +475,13 @@ class SignalAdapter(BasePlatformAdapter): async def _fetch_attachment(self, attachment_id: str) -> tuple: """Fetch an attachment via JSON-RPC and cache it. Returns (path, ext).""" - result = await self._rpc("getAttachment", { - "account": self.account, - "attachmentId": attachment_id, - }) + result = await self._rpc( + "getAttachment", + { + "account": self.account, + "attachmentId": attachment_id, + }, + ) if not result: return None, "" @@ -547,13 +546,13 @@ class SignalAdapter(BasePlatformAdapter): self, chat_id: str, text: str, - reply_to_message_id: Optional[str] = None, + reply_to_message_id: str | None = None, **kwargs, ) -> SendResult: """Send a text message.""" await self._stop_typing_indicator(chat_id) - params: Dict[str, Any] = { + params: dict[str, Any] = { "account": self.account, "message": text, } @@ -571,7 +570,7 @@ class SignalAdapter(BasePlatformAdapter): async def send_typing(self, chat_id: str) -> None: """Send a typing indicator.""" - params: Dict[str, Any] = { + params: dict[str, Any] = { "account": self.account, } @@ -586,7 +585,7 @@ class SignalAdapter(BasePlatformAdapter): self, chat_id: str, image_url: str, - caption: Optional[str] = None, + caption: str | None = None, **kwargs, ) -> SendResult: """Send an image. Supports http(s):// and file:// URLs.""" @@ -611,7 +610,7 @@ class SignalAdapter(BasePlatformAdapter): if file_size > SIGNAL_MAX_ATTACHMENT_SIZE: return SendResult(success=False, error=f"Image too large ({file_size} bytes)") - params: Dict[str, Any] = { + params: dict[str, Any] = { "account": self.account, "message": caption or "", "attachments": [file_path], @@ -631,8 +630,8 @@ class SignalAdapter(BasePlatformAdapter): self, chat_id: str, file_path: str, - caption: Optional[str] = None, - filename: Optional[str] = None, + caption: str | None = None, + filename: str | None = None, **kwargs, ) -> SendResult: """Send a document/file attachment.""" @@ -641,7 +640,7 @@ class SignalAdapter(BasePlatformAdapter): if not Path(file_path).exists(): return SendResult(success=False, error="File not found") - params: Dict[str, Any] = { + params: dict[str, Any] = { "account": self.account, "message": caption or "", "attachments": [file_path], @@ -690,7 +689,7 @@ class SignalAdapter(BasePlatformAdapter): # Chat Info # ------------------------------------------------------------------ - async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + async def get_chat_info(self, chat_id: str) -> dict[str, Any]: """Get information about a chat/contact.""" if chat_id.startswith("group:"): return { @@ -700,10 +699,13 @@ class SignalAdapter(BasePlatformAdapter): } # Try to resolve contact name - result = await self._rpc("getContact", { - "account": self.account, - "contactAddress": chat_id, - }) + result = await self._rpc( + "getContact", + { + "account": self.account, + "contactAddress": chat_id, + }, + ) name = chat_id if result and isinstance(result, dict): diff --git a/gateway/platforms/slack.py b/gateway/platforms/slack.py index 020843d3ac..18a6b5fa50 100644 --- a/gateway/platforms/slack.py +++ b/gateway/platforms/slack.py @@ -11,12 +11,13 @@ Uses slack-bolt (Python) with Socket Mode for: import asyncio import os import re -from typing import Dict, List, Optional, Any +from typing import Any try: - from slack_bolt.async_app import AsyncApp from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler + from slack_bolt.async_app import AsyncApp from slack_sdk.web.async_client import AsyncWebClient + SLACK_AVAILABLE = True except ImportError: SLACK_AVAILABLE = False @@ -26,18 +27,17 @@ except ImportError: import sys from pathlib import Path as _Path + sys.path.insert(0, str(_Path(__file__).resolve().parents[2])) from gateway.config import Platform, PlatformConfig from gateway.platforms.base import ( + SUPPORTED_DOCUMENT_TYPES, BasePlatformAdapter, MessageEvent, MessageType, SendResult, - SUPPORTED_DOCUMENT_TYPES, cache_document_from_bytes, - cache_image_from_url, - cache_audio_from_url, ) @@ -66,9 +66,9 @@ class SlackAdapter(BasePlatformAdapter): def __init__(self, config: PlatformConfig): super().__init__(config, Platform.SLACK) - self._app: Optional[AsyncApp] = None - self._handler: Optional[AsyncSocketModeHandler] = None - self._bot_user_id: Optional[str] = None + self._app: AsyncApp | None = None + self._handler: AsyncSocketModeHandler | None = None + self._bot_user_id: str | None = None async def connect(self) -> bool: """Connect to Slack via Socket Mode.""" @@ -135,8 +135,8 @@ class SlackAdapter(BasePlatformAdapter): self, chat_id: str, content: str, - reply_to: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, + reply_to: str | None = None, + metadata: dict[str, Any] | None = None, ) -> SendResult: """Send a message to a Slack channel or DM.""" if not self._app: @@ -193,8 +193,8 @@ class SlackAdapter(BasePlatformAdapter): self, chat_id: str, image_path: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """Send a local image file to Slack by uploading it.""" if not self._app: @@ -202,6 +202,7 @@ class SlackAdapter(BasePlatformAdapter): try: import os + if not os.path.exists(image_path): return SendResult(success=False, error=f"Image file not found: {image_path}") @@ -222,8 +223,8 @@ class SlackAdapter(BasePlatformAdapter): self, chat_id: str, image_url: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """Send an image to Slack by uploading the URL as a file.""" if not self._app: @@ -247,7 +248,7 @@ class SlackAdapter(BasePlatformAdapter): return SendResult(success=True, raw_response=result) - except Exception as e: + except Exception: # Fall back to sending the URL as text text = f"{caption}\n{image_url}" if caption else image_url return await self.send(chat_id=chat_id, content=text, reply_to=reply_to) @@ -256,8 +257,8 @@ class SlackAdapter(BasePlatformAdapter): self, chat_id: str, audio_path: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """Send an audio file to Slack.""" if not self._app: @@ -280,8 +281,8 @@ class SlackAdapter(BasePlatformAdapter): self, chat_id: str, video_path: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """Send a video file to Slack.""" if not self._app: @@ -308,9 +309,9 @@ class SlackAdapter(BasePlatformAdapter): self, chat_id: str, file_path: str, - caption: Optional[str] = None, - file_name: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + file_name: str | None = None, + reply_to: str | None = None, ) -> SendResult: """Send a document/file attachment to Slack.""" if not self._app: @@ -335,7 +336,7 @@ class SlackAdapter(BasePlatformAdapter): print(f"[{self.name}] Failed to send document: {e}") return await super().send_document(chat_id, file_path, caption, file_name, reply_to) - async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + async def get_chat_info(self, chat_id: str) -> dict[str, Any]: """Get information about a Slack channel.""" if not self._app: return {"name": chat_id, "type": "unknown"} @@ -442,9 +443,7 @@ class SlackAdapter(BasePlatformAdapter): # Download and cache raw_bytes = await self._download_slack_file_bytes(url) - cached_path = cache_document_from_bytes( - raw_bytes, original_filename or f"document{ext}" - ) + cached_path = cache_document_from_bytes(raw_bytes, original_filename or f"document{ext}") doc_mime = SUPPORTED_DOCUMENT_TYPES[ext] media_urls.append(cached_path) media_types.append(doc_mime) @@ -457,7 +456,7 @@ class SlackAdapter(BasePlatformAdapter): try: text_content = raw_bytes.decode("utf-8") display_name = original_filename or f"document{ext}" - display_name = re.sub(r'[^\w.\- ]', '_', display_name) + display_name = re.sub(r"[^\w.\- ]", "_", display_name) injection = f"[Content of {display_name}]:\n{text_content}" if text: text = f"{injection}\n\n{text}" @@ -499,16 +498,20 @@ class SlackAdapter(BasePlatformAdapter): # Map subcommands to gateway commands subcommand_map = { - "new": "/reset", "reset": "/reset", - "status": "/status", "stop": "/stop", + "new": "/reset", + "reset": "/reset", + "status": "/status", + "stop": "/stop", "help": "/help", - "model": "/model", "personality": "/personality", - "retry": "/retry", "undo": "/undo", + "model": "/model", + "personality": "/personality", + "retry": "/retry", + "undo": "/undo", } first_word = text.split()[0] if text else "" if first_word in subcommand_map: # Preserve arguments after the subcommand - rest = text[len(first_word):].strip() + rest = text[len(first_word) :].strip() text = f"{subcommand_map[first_word]} {rest}".strip() if rest else subcommand_map[first_word] elif text: pass # Treat as a regular question @@ -544,9 +547,11 @@ class SlackAdapter(BasePlatformAdapter): if audio: from gateway.platforms.base import cache_audio_from_bytes + return cache_audio_from_bytes(response.content, ext) else: from gateway.platforms.base import cache_image_from_bytes + return cache_image_from_bytes(response.content, ext) async def _download_slack_file_bytes(self, url: str) -> bytes: diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 4371bfdbde..86bb84c876 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -7,24 +7,26 @@ Uses python-telegram-bot library for: - Handling media and commands """ -import asyncio import logging import os import re -from typing import Dict, List, Optional, Any +from typing import Any logger = logging.getLogger(__name__) try: - from telegram import Update, Bot, Message + from telegram import Bot, Message, Update + from telegram.constants import ChatType, ParseMode from telegram.ext import ( Application, CommandHandler, - MessageHandler as TelegramMessageHandler, ContextTypes, filters, ) - from telegram.constants import ParseMode, ChatType + from telegram.ext import ( + MessageHandler as TelegramMessageHandler, + ) + TELEGRAM_AVAILABLE = True except ImportError: TELEGRAM_AVAILABLE = False @@ -42,22 +44,24 @@ except ImportError: # don't crash during class definition when the library isn't installed. class _MockContextTypes: DEFAULT_TYPE = Any + ContextTypes = _MockContextTypes import sys from pathlib import Path as _Path + sys.path.insert(0, str(_Path(__file__).resolve().parents[2])) from gateway.config import Platform, PlatformConfig from gateway.platforms.base import ( + SUPPORTED_DOCUMENT_TYPES, BasePlatformAdapter, MessageEvent, MessageType, SendResult, - cache_image_from_bytes, cache_audio_from_bytes, cache_document_from_bytes, - SUPPORTED_DOCUMENT_TYPES, + cache_image_from_bytes, ) @@ -68,12 +72,12 @@ def check_telegram_requirements() -> bool: # Matches every character that MarkdownV2 requires to be backslash-escaped # when it appears outside a code span or fenced code block. -_MDV2_ESCAPE_RE = re.compile(r'([_*\[\]()~`>#\+\-=|{}.!\\])') +_MDV2_ESCAPE_RE = re.compile(r"([_*\[\]()~`>#\+\-=|{}.!\\])") def _escape_mdv2(text: str) -> str: """Escape Telegram MarkdownV2 special characters with a preceding backslash.""" - return _MDV2_ESCAPE_RE.sub(r'\\\1', text) + return _MDV2_ESCAPE_RE.sub(r"\\\1", text) def _strip_mdv2(text: str) -> str: @@ -83,103 +87,108 @@ def _strip_mdv2(text: str) -> str: doesn't show stray asterisks from header/bold conversion. """ # Remove escape backslashes before special characters - cleaned = re.sub(r'\\([_*\[\]()~`>#\+\-=|{}.!\\])', r'\1', text) + cleaned = re.sub(r"\\([_*\[\]()~`>#\+\-=|{}.!\\])", r"\1", text) # Remove MarkdownV2 bold markers that format_message converted from **bold** - cleaned = re.sub(r'\*([^*]+)\*', r'\1', cleaned) + cleaned = re.sub(r"\*([^*]+)\*", r"\1", cleaned) return cleaned class TelegramAdapter(BasePlatformAdapter): """ Telegram bot adapter. - + Handles: - Receiving messages from users and groups - Sending responses with Telegram markdown - Forum topics (thread_id support) - Media messages """ - + # Telegram message limits MAX_MESSAGE_LENGTH = 4096 - + def __init__(self, config: PlatformConfig): super().__init__(config, Platform.TELEGRAM) - self._app: Optional[Application] = None - self._bot: Optional[Bot] = None - + self._app: Application | None = None + self._bot: Bot | None = None + async def connect(self) -> bool: """Connect to Telegram and start polling for updates.""" if not TELEGRAM_AVAILABLE: print(f"[{self.name}] python-telegram-bot not installed. Run: pip install python-telegram-bot") return False - + if not self.config.token: print(f"[{self.name}] No bot token configured") return False - + try: # Build the application self._app = Application.builder().token(self.config.token).build() self._bot = self._app.bot - + # Register handlers - self._app.add_handler(TelegramMessageHandler( - filters.TEXT & ~filters.COMMAND, - self._handle_text_message - )) - self._app.add_handler(TelegramMessageHandler( - filters.COMMAND, - self._handle_command - )) - self._app.add_handler(TelegramMessageHandler( - filters.LOCATION | getattr(filters, "VENUE", filters.LOCATION), - self._handle_location_message - )) - self._app.add_handler(TelegramMessageHandler( - filters.PHOTO | filters.VIDEO | filters.AUDIO | filters.VOICE | filters.Document.ALL | filters.Sticker.ALL, - self._handle_media_message - )) - + self._app.add_handler(TelegramMessageHandler(filters.TEXT & ~filters.COMMAND, self._handle_text_message)) + self._app.add_handler(TelegramMessageHandler(filters.COMMAND, self._handle_command)) + self._app.add_handler( + TelegramMessageHandler( + filters.LOCATION | getattr(filters, "VENUE", filters.LOCATION), self._handle_location_message + ) + ) + self._app.add_handler( + TelegramMessageHandler( + filters.PHOTO + | filters.VIDEO + | filters.AUDIO + | filters.VOICE + | filters.Document.ALL + | filters.Sticker.ALL, + self._handle_media_message, + ) + ) + # Start polling in background await self._app.initialize() await self._app.start() await self._app.updater.start_polling(allowed_updates=Update.ALL_TYPES) - + # Register bot commands so Telegram shows a hint menu when users type / try: from telegram import BotCommand - await self._bot.set_my_commands([ - BotCommand("new", "Start a new conversation"), - BotCommand("reset", "Reset conversation history"), - BotCommand("model", "Show or change the model"), - BotCommand("personality", "Set a personality"), - BotCommand("retry", "Retry your last message"), - BotCommand("undo", "Remove the last exchange"), - BotCommand("status", "Show session info"), - BotCommand("stop", "Stop the running agent"), - BotCommand("sethome", "Set this chat as the home channel"), - BotCommand("compress", "Compress conversation context"), - BotCommand("title", "Set or show the session title"), - BotCommand("resume", "Resume a previously-named session"), - BotCommand("usage", "Show token usage for this session"), - BotCommand("provider", "Show available providers"), - BotCommand("insights", "Show usage insights and analytics"), - BotCommand("update", "Update Hermes to the latest version"), - BotCommand("reload_mcp", "Reload MCP servers from config"), - BotCommand("help", "Show available commands"), - ]) + + await self._bot.set_my_commands( + [ + BotCommand("new", "Start a new conversation"), + BotCommand("reset", "Reset conversation history"), + BotCommand("model", "Show or change the model"), + BotCommand("personality", "Set a personality"), + BotCommand("retry", "Retry your last message"), + BotCommand("undo", "Remove the last exchange"), + BotCommand("status", "Show session info"), + BotCommand("stop", "Stop the running agent"), + BotCommand("sethome", "Set this chat as the home channel"), + BotCommand("compress", "Compress conversation context"), + BotCommand("title", "Set or show the session title"), + BotCommand("resume", "Resume a previously-named session"), + BotCommand("usage", "Show token usage for this session"), + BotCommand("provider", "Show available providers"), + BotCommand("insights", "Show usage insights and analytics"), + BotCommand("update", "Update Hermes to the latest version"), + BotCommand("reload_mcp", "Reload MCP servers from config"), + BotCommand("help", "Show available commands"), + ] + ) except Exception as e: print(f"[{self.name}] Could not register command menu: {e}") - + self._running = True print(f"[{self.name}] Connected and polling for updates") return True - + except Exception as e: print(f"[{self.name}] Failed to connect: {e}") return False - + async def disconnect(self) -> None: """Stop polling and disconnect.""" if self._app: @@ -189,31 +198,27 @@ class TelegramAdapter(BasePlatformAdapter): await self._app.shutdown() except Exception as e: print(f"[{self.name}] Error during disconnect: {e}") - + self._running = False self._app = None self._bot = None print(f"[{self.name}] Disconnected") - + async def send( - self, - chat_id: str, - content: str, - reply_to: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None + self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None ) -> SendResult: """Send a message to a Telegram chat.""" if not self._bot: return SendResult(success=False, error="Not connected") - + try: # Format and split message if needed formatted = self.format_message(content) chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH) - + message_ids = [] thread_id = metadata.get("thread_id") if metadata else None - + for i, chunk in enumerate(chunks): # Try Markdown first, fall back to plain text if it fails try: @@ -227,7 +232,9 @@ class TelegramAdapter(BasePlatformAdapter): except Exception as md_error: # Markdown parsing failed, try plain text if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower(): - logger.warning("[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error) + logger.warning( + "[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error + ) # Strip MDV2 escape backslashes so the user doesn't # see raw backslashes littered through the message. plain_chunk = _strip_mdv2(chunk) @@ -241,13 +248,13 @@ class TelegramAdapter(BasePlatformAdapter): else: raise # Re-raise if not a parse error message_ids.append(str(msg.message_id)) - + return SendResult( success=True, message_id=message_ids[0] if message_ids else None, - raw_response={"message_ids": message_ids} + raw_response={"message_ids": message_ids}, ) - + except Exception as e: return SendResult(success=False, error=str(e)) @@ -284,18 +291,19 @@ class TelegramAdapter(BasePlatformAdapter): self, chat_id: str, audio_path: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """Send audio as a native Telegram voice message or audio file.""" if not self._bot: return SendResult(success=False, error="Not connected") - + try: import os + if not os.path.exists(audio_path): return SendResult(success=False, error=f"Audio file not found: {audio_path}") - + with open(audio_path, "rb") as audio_file: # .ogg files -> send as voice (round playable bubble) if audio_path.endswith(".ogg") or audio_path.endswith(".opus"): @@ -317,23 +325,24 @@ class TelegramAdapter(BasePlatformAdapter): except Exception as e: print(f"[{self.name}] Failed to send voice/audio: {e}") return await super().send_voice(chat_id, audio_path, caption, reply_to) - + async def send_image_file( self, chat_id: str, image_path: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """Send a local image file natively as a Telegram photo.""" if not self._bot: return SendResult(success=False, error="Not connected") - + try: import os + if not os.path.exists(image_path): return SendResult(success=False, error=f"Image file not found: {image_path}") - + with open(image_path, "rb") as image_file: msg = await self._bot.send_photo( chat_id=int(chat_id), @@ -350,17 +359,17 @@ class TelegramAdapter(BasePlatformAdapter): self, chat_id: str, image_url: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """Send an image natively as a Telegram photo. - + Tries URL-based send first (fast, works for <5MB images). Falls back to downloading and uploading as file (supports up to 10MB). """ if not self._bot: return SendResult(success=False, error="Not connected") - + try: # Telegram can send photos directly from URLs (up to ~5MB) msg = await self._bot.send_photo( @@ -375,11 +384,12 @@ class TelegramAdapter(BasePlatformAdapter): # Fallback: download and upload as file (supports up to 10MB) try: import httpx + async with httpx.AsyncClient(timeout=30.0) as client: resp = await client.get(image_url) resp.raise_for_status() image_data = resp.content - + msg = await self._bot.send_photo( chat_id=int(chat_id), photo=image_data, @@ -391,18 +401,18 @@ class TelegramAdapter(BasePlatformAdapter): logger.error("[%s] File upload send_photo also failed: %s", self.name, e2) # Final fallback: send URL as text return await super().send_image(chat_id, image_url, caption, reply_to) - + async def send_animation( self, chat_id: str, animation_url: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """Send an animated GIF natively as a Telegram animation (auto-plays inline).""" if not self._bot: return SendResult(success=False, error="Not connected") - + try: msg = await self._bot.send_animation( chat_id=int(chat_id), @@ -420,21 +430,18 @@ class TelegramAdapter(BasePlatformAdapter): """Send typing indicator.""" if self._bot: try: - await self._bot.send_chat_action( - chat_id=int(chat_id), - action="typing" - ) + await self._bot.send_chat_action(chat_id=int(chat_id), action="typing") except Exception: pass # Ignore typing indicator failures - - async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + + async def get_chat_info(self, chat_id: str) -> dict[str, Any]: """Get information about a Telegram chat.""" if not self._bot: return {"name": "Unknown", "type": "dm"} - + try: chat = await self._bot.get_chat(int(chat_id)) - + chat_type = "dm" if chat.type == ChatType.GROUP: chat_type = "group" @@ -444,7 +451,7 @@ class TelegramAdapter(BasePlatformAdapter): chat_type = "forum" elif chat.type == ChatType.CHANNEL: chat_type = "channel" - + return { "name": chat.title or chat.full_name or str(chat_id), "type": chat_type, @@ -453,7 +460,7 @@ class TelegramAdapter(BasePlatformAdapter): } except Exception as e: return {"name": str(chat_id), "type": "dm", "error": str(e)} - + def format_message(self, content: str) -> str: """ Convert standard markdown to Telegram MarkdownV2 format. @@ -480,38 +487,36 @@ class TelegramAdapter(BasePlatformAdapter): # 1) Protect fenced code blocks (``` ... ```) text = re.sub( - r'(```(?:[^\n]*\n)?[\s\S]*?```)', + r"(```(?:[^\n]*\n)?[\s\S]*?```)", lambda m: _ph(m.group(0)), text, ) # 2) Protect inline code (`...`) - text = re.sub(r'(`[^`]+`)', lambda m: _ph(m.group(0)), text) + text = re.sub(r"(`[^`]+`)", lambda m: _ph(m.group(0)), text) # 3) Convert markdown links – escape the display text; inside the URL # only ')' and '\' need escaping per the MarkdownV2 spec. def _convert_link(m): display = _escape_mdv2(m.group(1)) - url = m.group(2).replace('\\', '\\\\').replace(')', '\\)') - return _ph(f'[{display}]({url})') + url = m.group(2).replace("\\", "\\\\").replace(")", "\\)") + return _ph(f"[{display}]({url})") - text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', _convert_link, text) + text = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", _convert_link, text) # 4) Convert markdown headers (## Title) → bold *Title* def _convert_header(m): inner = m.group(1).strip() # Strip redundant bold markers that may appear inside a header - inner = re.sub(r'\*\*(.+?)\*\*', r'\1', inner) - return _ph(f'*{_escape_mdv2(inner)}*') + inner = re.sub(r"\*\*(.+?)\*\*", r"\1", inner) + return _ph(f"*{_escape_mdv2(inner)}*") - text = re.sub( - r'^#{1,6}\s+(.+)$', _convert_header, text, flags=re.MULTILINE - ) + text = re.sub(r"^#{1,6}\s+(.+)$", _convert_header, text, flags=re.MULTILINE) # 5) Convert bold: **text** → *text* (MarkdownV2 bold) text = re.sub( - r'\*\*(.+?)\*\*', - lambda m: _ph(f'*{_escape_mdv2(m.group(1))}*'), + r"\*\*(.+?)\*\*", + lambda m: _ph(f"*{_escape_mdv2(m.group(1))}*"), text, ) @@ -519,8 +524,8 @@ class TelegramAdapter(BasePlatformAdapter): # [^*\n]+ prevents matching across newlines (which would corrupt # bullet lists using * markers and multi-line content). text = re.sub( - r'\*([^*\n]+)\*', - lambda m: _ph(f'_{_escape_mdv2(m.group(1))}_'), + r"\*([^*\n]+)\*", + lambda m: _ph(f"_{_escape_mdv2(m.group(1))}_"), text, ) @@ -533,23 +538,23 @@ class TelegramAdapter(BasePlatformAdapter): text = text.replace(key, placeholders[key]) return text - + async def _handle_text_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Handle incoming text messages.""" if not update.message or not update.message.text: return - + event = self._build_message_event(update.message, MessageType.TEXT) await self.handle_message(event) - + async def _handle_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Handle incoming command messages.""" if not update.message or not update.message.text: return - + event = self._build_message_event(update.message, MessageType.COMMAND) await self.handle_message(event) - + async def _handle_location_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Handle incoming location/venue pin messages.""" if not update.message: @@ -589,9 +594,9 @@ class TelegramAdapter(BasePlatformAdapter): """Handle incoming media messages, downloading images to local cache.""" if not update.message: return - + msg = update.message - + # Determine media type if msg.sticker: msg_type = MessageType.STICKER @@ -607,19 +612,19 @@ class TelegramAdapter(BasePlatformAdapter): msg_type = MessageType.DOCUMENT else: msg_type = MessageType.DOCUMENT - + event = self._build_message_event(msg, msg_type) - + # Add caption as text if msg.caption: event.text = msg.caption - + # Handle stickers: describe via vision tool with caching if msg.sticker: await self._handle_sticker(msg, event) await self.handle_message(event) return - + # Download photo to local image cache so the vision tool can access it # even after Telegram's ephemeral file URLs expire (~1 hour). if msg.photo: @@ -643,7 +648,7 @@ class TelegramAdapter(BasePlatformAdapter): print(f"[Telegram] Cached user photo: {cached_path}", flush=True) except Exception as e: print(f"[Telegram] Failed to cache photo: {e}", flush=True) - + # Download voice/audio messages to cache for STT transcription if msg.voice: try: @@ -685,10 +690,7 @@ class TelegramAdapter(BasePlatformAdapter): # Check if supported if ext not in SUPPORTED_DOCUMENT_TYPES: supported_list = ", ".join(sorted(SUPPORTED_DOCUMENT_TYPES.keys())) - event.text = ( - f"Unsupported document type '{ext or 'unknown'}'. " - f"Supported types: {supported_list}" - ) + event.text = f"Unsupported document type '{ext or 'unknown'}'. Supported types: {supported_list}" print(f"[Telegram] Unsupported document type: {ext or 'unknown'}", flush=True) await self.handle_message(event) return @@ -696,10 +698,7 @@ class TelegramAdapter(BasePlatformAdapter): # Check file size (Telegram Bot API limit: 20 MB) MAX_DOC_BYTES = 20 * 1024 * 1024 if not doc.file_size or doc.file_size > MAX_DOC_BYTES: - event.text = ( - "The document is too large or its size could not be verified. " - "Maximum: 20 MB." - ) + event.text = "The document is too large or its size could not be verified. Maximum: 20 MB." print(f"[Telegram] Document too large: {doc.file_size} bytes", flush=True) await self.handle_message(event) return @@ -720,20 +719,20 @@ class TelegramAdapter(BasePlatformAdapter): try: text_content = raw_bytes.decode("utf-8") display_name = original_filename or f"document{ext}" - display_name = re.sub(r'[^\w.\- ]', '_', display_name) + display_name = re.sub(r"[^\w.\- ]", "_", display_name) injection = f"[Content of {display_name}]:\n{text_content}" if event.text: event.text = f"{injection}\n\n{event.text}" else: event.text = injection except UnicodeDecodeError: - print(f"[Telegram] Could not decode text file as UTF-8, skipping content injection", flush=True) + print("[Telegram] Could not decode text file as UTF-8, skipping content injection", flush=True) except Exception as e: print(f"[Telegram] Failed to cache document: {e}", flush=True) await self.handle_message(event) - + async def _handle_sticker(self, msg: Message, event: "MessageEvent") -> None: """ Describe a Telegram sticker via vision analysis, with caching. @@ -743,11 +742,11 @@ class TelegramAdapter(BasePlatformAdapter): a placeholder noting the emoji. """ from gateway.sticker_cache import ( - get_cached_description, - cache_sticker_description, - build_sticker_injection, - build_animated_sticker_injection, STICKER_VISION_PROMPT, + build_animated_sticker_injection, + build_sticker_injection, + cache_sticker_description, + get_cached_description, ) sticker = msg.sticker @@ -775,9 +774,10 @@ class TelegramAdapter(BasePlatformAdapter): cached_path = cache_image_from_bytes(bytes(image_bytes), ext=".webp") print(f"[Telegram] Analyzing sticker: {cached_path}", flush=True) - from tools.vision_tools import vision_analyze_tool import json as _json + from tools.vision_tools import vision_analyze_tool + result_json = await vision_analyze_tool( image_url=cached_path, user_prompt=STICKER_VISION_PROMPT, @@ -792,27 +792,29 @@ class TelegramAdapter(BasePlatformAdapter): # Vision failed -- use emoji as fallback event.text = build_sticker_injection( f"a sticker with emoji {emoji}" if emoji else "a sticker", - emoji, set_name, + emoji, + set_name, ) except Exception as e: print(f"[Telegram] Sticker analysis error: {e}", flush=True) event.text = build_sticker_injection( f"a sticker with emoji {emoji}" if emoji else "a sticker", - emoji, set_name, + emoji, + set_name, ) def _build_message_event(self, message: Message, msg_type: MessageType) -> MessageEvent: """Build a MessageEvent from a Telegram message.""" chat = message.chat user = message.from_user - + # Determine chat type chat_type = "dm" if chat.type in (ChatType.GROUP, ChatType.SUPERGROUP): chat_type = "group" elif chat.type == ChatType.CHANNEL: chat_type = "channel" - + # Build source source = self.build_source( chat_id=str(chat.id), @@ -822,7 +824,7 @@ class TelegramAdapter(BasePlatformAdapter): user_name=user.full_name if user else None, thread_id=str(message.message_thread_id) if message.message_thread_id else None, ) - + return MessageEvent( text=message.text or "", message_type=msg_type, diff --git a/gateway/platforms/whatsapp.py b/gateway/platforms/whatsapp.py index 285a89eef2..0d7f30525d 100644 --- a/gateway/platforms/whatsapp.py +++ b/gateway/platforms/whatsapp.py @@ -16,7 +16,6 @@ with different backends via a bridge pattern. """ import asyncio -import json import logging import os import platform @@ -24,7 +23,7 @@ import subprocess _IS_WINDOWS = platform.system() == "Windows" from pathlib import Path -from typing import Dict, List, Optional, Any +from typing import Any logger = logging.getLogger(__name__) @@ -36,7 +35,9 @@ def _kill_port_process(port: int) -> None: # Use netstat to find the PID bound to this port, then taskkill result = subprocess.run( ["netstat", "-ano", "-p", "TCP"], - capture_output=True, text=True, timeout=5, + capture_output=True, + text=True, + timeout=5, ) for line in result.stdout.splitlines(): parts = line.split() @@ -46,24 +47,29 @@ def _kill_port_process(port: int) -> None: try: subprocess.run( ["taskkill", "/PID", parts[4], "/F"], - capture_output=True, timeout=5, + capture_output=True, + timeout=5, ) except subprocess.SubprocessError: pass else: result = subprocess.run( ["fuser", f"{port}/tcp"], - capture_output=True, timeout=5, + capture_output=True, + timeout=5, ) if result.returncode == 0: subprocess.run( ["fuser", "-k", f"{port}/tcp"], - capture_output=True, timeout=5, + capture_output=True, + timeout=5, ) except Exception: pass + import sys + sys.path.insert(0, str(Path(__file__).resolve().parents[2])) from gateway.config import Platform, PlatformConfig @@ -72,25 +78,20 @@ from gateway.platforms.base import ( MessageEvent, MessageType, SendResult, - cache_image_from_url, cache_audio_from_url, + cache_image_from_url, ) def check_whatsapp_requirements() -> bool: """ Check if WhatsApp dependencies are available. - + WhatsApp requires a Node.js bridge for most implementations. """ # Check for Node.js try: - result = subprocess.run( - ["node", "--version"], - capture_output=True, - text=True, - timeout=5 - ) + result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=5) return result.returncode == 0 except Exception: return False @@ -99,62 +100,61 @@ def check_whatsapp_requirements() -> bool: class WhatsAppAdapter(BasePlatformAdapter): """ WhatsApp adapter. - + This implementation uses a simple HTTP bridge pattern where: 1. A Node.js process runs the WhatsApp Web client 2. Messages are forwarded via HTTP/IPC to this Python adapter 3. Responses are sent back through the bridge - + The actual Node.js bridge implementation can vary: - whatsapp-web.js based - Baileys based - Business API based - + Configuration: - bridge_script: Path to the Node.js bridge script - bridge_port: Port for HTTP communication (default: 3000) - session_path: Path to store WhatsApp session data """ - + # WhatsApp message limits MAX_MESSAGE_LENGTH = 65536 # WhatsApp allows longer messages - + # Default bridge location relative to the hermes-agent install _DEFAULT_BRIDGE_DIR = Path(__file__).resolve().parents[2] / "scripts" / "whatsapp-bridge" def __init__(self, config: PlatformConfig): super().__init__(config, Platform.WHATSAPP) - self._bridge_process: Optional[subprocess.Popen] = None + self._bridge_process: subprocess.Popen | None = None self._bridge_port: int = config.extra.get("bridge_port", 3000) - self._bridge_script: Optional[str] = config.extra.get( + self._bridge_script: str | None = config.extra.get( "bridge_script", str(self._DEFAULT_BRIDGE_DIR / "bridge.js"), ) - self._session_path: Path = Path(config.extra.get( - "session_path", - Path.home() / ".hermes" / "whatsapp" / "session" - )) + self._session_path: Path = Path( + config.extra.get("session_path", Path.home() / ".hermes" / "whatsapp" / "session") + ) self._message_queue: asyncio.Queue = asyncio.Queue() self._bridge_log_fh = None - self._bridge_log: Optional[Path] = None - + self._bridge_log: Path | None = None + async def connect(self) -> bool: """ Start the WhatsApp bridge. - + This launches the Node.js bridge process and waits for it to be ready. """ if not check_whatsapp_requirements(): logger.warning("[%s] Node.js not found. WhatsApp requires Node.js.", self.name) return False - + bridge_path = Path(self._bridge_script) if not bridge_path.exists(): logger.warning("[%s] Bridge script not found: %s", self.name, bridge_path) return False - + logger.info("[%s] Bridge found at %s", self.name, bridge_path) - + # Auto-install npm dependencies if node_modules doesn't exist bridge_dir = bridge_path.parent if not (bridge_dir / "node_modules").exists(): @@ -174,16 +174,17 @@ class WhatsAppAdapter(BasePlatformAdapter): except Exception as e: print(f"[{self.name}] Failed to install dependencies: {e}") return False - + try: # Ensure session directory exists self._session_path.mkdir(parents=True, exist_ok=True) - + # Kill any orphaned bridge from a previous gateway run _kill_port_process(self._bridge_port) import time + time.sleep(1) - + # Start the bridge process in its own process group. # Route output to a log file so QR codes, errors, and reconnection # messages are preserved for troubleshooting. @@ -195,19 +196,23 @@ class WhatsAppAdapter(BasePlatformAdapter): [ "node", str(bridge_path), - "--port", str(self._bridge_port), - "--session", str(self._session_path), - "--mode", whatsapp_mode, + "--port", + str(self._bridge_port), + "--session", + str(self._session_path), + "--mode", + whatsapp_mode, ], stdout=bridge_log_fh, stderr=bridge_log_fh, preexec_fn=None if _IS_WINDOWS else os.setsid, ) - + # Wait for the bridge to connect to WhatsApp. # Phase 1: wait for the HTTP server to come up (up to 15s). # Phase 2: wait for WhatsApp status: connected (up to 15s more). import aiohttp + http_ready = False data = {} for attempt in range(15): @@ -218,17 +223,18 @@ class WhatsAppAdapter(BasePlatformAdapter): self._close_bridge_log() return False try: - async with aiohttp.ClientSession() as session: - async with session.get( - f"http://localhost:{self._bridge_port}/health", - timeout=aiohttp.ClientTimeout(total=2) - ) as resp: - if resp.status == 200: - http_ready = True - data = await resp.json() - if data.get("status") == "connected": - print(f"[{self.name}] Bridge ready (status: connected)") - break + async with ( + aiohttp.ClientSession() as session, + session.get( + f"http://localhost:{self._bridge_port}/health", timeout=aiohttp.ClientTimeout(total=2) + ) as resp, + ): + if resp.status == 200: + http_ready = True + data = await resp.json() + if data.get("status") == "connected": + print(f"[{self.name}] Bridge ready (status: connected)") + break except Exception: continue @@ -237,7 +243,7 @@ class WhatsAppAdapter(BasePlatformAdapter): print(f"[{self.name}] Check log: {self._bridge_log}") self._close_bridge_log() return False - + # Phase 2: HTTP is up but WhatsApp may still be connecting. # Give it more time to authenticate with saved credentials. if data.get("status") != "connected": @@ -250,16 +256,17 @@ class WhatsAppAdapter(BasePlatformAdapter): self._close_bridge_log() return False try: - async with aiohttp.ClientSession() as session: - async with session.get( - f"http://localhost:{self._bridge_port}/health", - timeout=aiohttp.ClientTimeout(total=2) - ) as resp: - if resp.status == 200: - data = await resp.json() - if data.get("status") == "connected": - print(f"[{self.name}] Bridge ready (status: connected)") - break + async with ( + aiohttp.ClientSession() as session, + session.get( + f"http://localhost:{self._bridge_port}/health", timeout=aiohttp.ClientTimeout(total=2) + ) as resp, + ): + if resp.status == 200: + data = await resp.json() + if data.get("status") == "connected": + print(f"[{self.name}] Bridge ready (status: connected)") + break except Exception: continue else: @@ -268,19 +275,19 @@ class WhatsAppAdapter(BasePlatformAdapter): print(f"[{self.name}] ⚠ WhatsApp not connected after 30s") print(f"[{self.name}] Bridge log: {self._bridge_log}") print(f"[{self.name}] If session expired, re-pair: hermes whatsapp") - + # Start message polling task asyncio.create_task(self._poll_messages()) - + self._running = True print(f"[{self.name}] Bridge started on port {self._bridge_port}") return True - + except Exception as e: logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True) self._close_bridge_log() return False - + def _close_bridge_log(self) -> None: """Close the bridge log file handle if open.""" if self._bridge_log_fh: @@ -296,6 +303,7 @@ class WhatsAppAdapter(BasePlatformAdapter): try: # Kill the entire process group so child node processes die too import signal + try: if _IS_WINDOWS: self._bridge_process.terminate() @@ -314,29 +322,25 @@ class WhatsAppAdapter(BasePlatformAdapter): self._bridge_process.kill() except Exception as e: print(f"[{self.name}] Error stopping bridge: {e}") - + # Also kill any orphaned bridge processes on our port _kill_port_process(self._bridge_port) - + self._running = False self._bridge_process = None self._close_bridge_log() print(f"[{self.name}] Disconnected") - + async def send( - self, - chat_id: str, - content: str, - reply_to: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None + self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None ) -> SendResult: """Send a message via the WhatsApp bridge.""" if not self._running: return SendResult(success=False, error="Not connected") - + try: import aiohttp - + async with aiohttp.ClientSession() as session: payload = { "chatId": chat_id, @@ -344,28 +348,19 @@ class WhatsAppAdapter(BasePlatformAdapter): } if reply_to: payload["replyTo"] = reply_to - + async with session.post( - f"http://localhost:{self._bridge_port}/send", - json=payload, - timeout=aiohttp.ClientTimeout(total=30) + f"http://localhost:{self._bridge_port}/send", json=payload, timeout=aiohttp.ClientTimeout(total=30) ) as resp: if resp.status == 200: data = await resp.json() - return SendResult( - success=True, - message_id=data.get("messageId"), - raw_response=data - ) + return SendResult(success=True, message_id=data.get("messageId"), raw_response=data) else: error = await resp.text() return SendResult(success=False, error=error) - + except ImportError: - return SendResult( - success=False, - error="aiohttp not installed. Run: pip install aiohttp" - ) + return SendResult(success=False, error="aiohttp not installed. Run: pip install aiohttp") except Exception as e: return SendResult(success=False, error=str(e)) @@ -380,21 +375,24 @@ class WhatsAppAdapter(BasePlatformAdapter): return SendResult(success=False, error="Not connected") try: import aiohttp - async with aiohttp.ClientSession() as session: - async with session.post( + + async with ( + aiohttp.ClientSession() as session, + session.post( f"http://localhost:{self._bridge_port}/edit", json={ "chatId": chat_id, "messageId": message_id, "message": content, }, - timeout=aiohttp.ClientTimeout(total=15) - ) as resp: - if resp.status == 200: - return SendResult(success=True, message_id=message_id) - else: - error = await resp.text() - return SendResult(success=False, error=error) + timeout=aiohttp.ClientTimeout(total=15), + ) as resp, + ): + if resp.status == 200: + return SendResult(success=True, message_id=message_id) + else: + error = await resp.text() + return SendResult(success=False, error=error) except Exception as e: return SendResult(success=False, error=str(e)) @@ -403,8 +401,8 @@ class WhatsAppAdapter(BasePlatformAdapter): chat_id: str, file_path: str, media_type: str, - caption: Optional[str] = None, - file_name: Optional[str] = None, + caption: str | None = None, + file_name: str | None = None, ) -> SendResult: """Send any media file via bridge /send-media endpoint.""" if not self._running: @@ -415,7 +413,7 @@ class WhatsAppAdapter(BasePlatformAdapter): if not os.path.exists(file_path): return SendResult(success=False, error=f"File not found: {file_path}") - payload: Dict[str, Any] = { + payload: dict[str, Any] = { "chatId": chat_id, "filePath": file_path, "mediaType": media_type, @@ -425,22 +423,24 @@ class WhatsAppAdapter(BasePlatformAdapter): if file_name: payload["fileName"] = file_name - async with aiohttp.ClientSession() as session: - async with session.post( + async with ( + aiohttp.ClientSession() as session, + session.post( f"http://localhost:{self._bridge_port}/send-media", json=payload, timeout=aiohttp.ClientTimeout(total=120), - ) as resp: - if resp.status == 200: - data = await resp.json() - return SendResult( - success=True, - message_id=data.get("messageId"), - raw_response=data, - ) - else: - error = await resp.text() - return SendResult(success=False, error=error) + ) as resp, + ): + if resp.status == 200: + data = await resp.json() + return SendResult( + success=True, + message_id=data.get("messageId"), + raw_response=data, + ) + else: + error = await resp.text() + return SendResult(success=False, error=error) except Exception as e: return SendResult(success=False, error=str(e)) @@ -449,8 +449,8 @@ class WhatsAppAdapter(BasePlatformAdapter): self, chat_id: str, image_url: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """Download image URL to cache, send natively via bridge.""" try: @@ -463,8 +463,8 @@ class WhatsAppAdapter(BasePlatformAdapter): self, chat_id: str, image_path: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """Send a local image file natively via bridge.""" return await self._send_media_to_bridge(chat_id, image_path, "image", caption) @@ -473,8 +473,8 @@ class WhatsAppAdapter(BasePlatformAdapter): self, chat_id: str, video_path: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + reply_to: str | None = None, ) -> SendResult: """Send a video natively via bridge — plays inline in WhatsApp.""" return await self._send_media_to_bridge(chat_id, video_path, "video", caption) @@ -483,13 +483,16 @@ class WhatsAppAdapter(BasePlatformAdapter): self, chat_id: str, file_path: str, - caption: Optional[str] = None, - file_name: Optional[str] = None, - reply_to: Optional[str] = None, + caption: str | None = None, + file_name: str | None = None, + reply_to: str | None = None, ) -> SendResult: """Send a document/file as a downloadable attachment via bridge.""" return await self._send_media_to_bridge( - chat_id, file_path, "document", caption, + chat_id, + file_path, + "document", + caption, file_name or os.path.basename(file_path), ) @@ -497,44 +500,45 @@ class WhatsAppAdapter(BasePlatformAdapter): """Send typing indicator via bridge.""" if not self._running: return - + try: import aiohttp - + async with aiohttp.ClientSession() as session: await session.post( f"http://localhost:{self._bridge_port}/typing", json={"chatId": chat_id}, - timeout=aiohttp.ClientTimeout(total=5) + timeout=aiohttp.ClientTimeout(total=5), ) except Exception: pass # Ignore typing indicator failures - - async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + + async def get_chat_info(self, chat_id: str) -> dict[str, Any]: """Get information about a WhatsApp chat.""" if not self._running: return {"name": "Unknown", "type": "dm"} - + try: import aiohttp - - async with aiohttp.ClientSession() as session: - async with session.get( - f"http://localhost:{self._bridge_port}/chat/{chat_id}", - timeout=aiohttp.ClientTimeout(total=10) - ) as resp: - if resp.status == 200: - data = await resp.json() - return { - "name": data.get("name", chat_id), - "type": "group" if data.get("isGroup") else "dm", - "participants": data.get("participants", []), - } + + async with ( + aiohttp.ClientSession() as session, + session.get( + f"http://localhost:{self._bridge_port}/chat/{chat_id}", timeout=aiohttp.ClientTimeout(total=10) + ) as resp, + ): + if resp.status == 200: + data = await resp.json() + return { + "name": data.get("name", chat_id), + "type": "group" if data.get("isGroup") else "dm", + "participants": data.get("participants", []), + } except Exception as e: logger.debug("Could not get WhatsApp chat info for %s: %s", chat_id, e) - + return {"name": chat_id, "type": "dm"} - + async def _poll_messages(self) -> None: """Poll the bridge for incoming messages.""" try: @@ -542,29 +546,30 @@ class WhatsAppAdapter(BasePlatformAdapter): except ImportError: print(f"[{self.name}] aiohttp not installed, message polling disabled") return - + while self._running: try: - async with aiohttp.ClientSession() as session: - async with session.get( - f"http://localhost:{self._bridge_port}/messages", - timeout=aiohttp.ClientTimeout(total=30) - ) as resp: - if resp.status == 200: - messages = await resp.json() - for msg_data in messages: - event = await self._build_message_event(msg_data) - if event: - await self.handle_message(event) + async with ( + aiohttp.ClientSession() as session, + session.get( + f"http://localhost:{self._bridge_port}/messages", timeout=aiohttp.ClientTimeout(total=30) + ) as resp, + ): + if resp.status == 200: + messages = await resp.json() + for msg_data in messages: + event = await self._build_message_event(msg_data) + if event: + await self.handle_message(event) except asyncio.CancelledError: break except Exception as e: print(f"[{self.name}] Poll error: {e}") await asyncio.sleep(5) - + await asyncio.sleep(1) # Poll interval - - async def _build_message_event(self, data: Dict[str, Any]) -> Optional[MessageEvent]: + + async def _build_message_event(self, data: dict[str, Any]) -> MessageEvent | None: """Build a MessageEvent from bridge message data, downloading images to cache.""" try: # Determine message type @@ -579,11 +584,11 @@ class WhatsAppAdapter(BasePlatformAdapter): msg_type = MessageType.VOICE else: msg_type = MessageType.DOCUMENT - + # Determine chat type is_group = data.get("isGroup", False) chat_type = "group" if is_group else "dm" - + # Build source source = self.build_source( chat_id=data.get("chatId", ""), @@ -592,7 +597,7 @@ class WhatsAppAdapter(BasePlatformAdapter): user_id=data.get("senderId"), user_name=data.get("senderName"), ) - + # Download image media URLs to the local cache so the vision tool # can access them reliably regardless of URL expiration. raw_urls = data.get("mediaUrls", []) @@ -622,7 +627,7 @@ class WhatsAppAdapter(BasePlatformAdapter): else: cached_urls.append(url) media_types.append("unknown") - + return MessageEvent( text=data.get("body", ""), message_type=msg_type, @@ -635,4 +640,3 @@ class WhatsAppAdapter(BasePlatformAdapter): except Exception as e: print(f"[{self.name}] Error building event: {e}") return None - diff --git a/gateway/run.py b/gateway/run.py index 6dd1a280a5..ef331c8941 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -8,7 +8,7 @@ This module provides: Usage: # Start the gateway python -m gateway.run - + # Or from CLI python cli.py --gateway """ @@ -17,13 +17,13 @@ import asyncio import logging import os import re -import sys import signal +import sys import threading +from datetime import datetime from logging.handlers import RotatingFileHandler from pathlib import Path -from datetime import datetime -from typing import Dict, Optional, Any, List +from typing import Any # Add parent directory to path sys.path.insert(0, str(Path(__file__).parent.parent)) @@ -33,7 +33,8 @@ _hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) # Load environment variables from ~/.hermes/.env first from dotenv import load_dotenv -_env_path = _hermes_home / '.env' + +_env_path = _hermes_home / ".env" if _env_path.exists(): try: load_dotenv(_env_path, encoding="utf-8") @@ -44,10 +45,11 @@ load_dotenv() # Bridge config.yaml values into the environment so os.getenv() picks them up. # config.yaml is authoritative for terminal settings — overrides .env. -_config_path = _hermes_home / 'config.yaml' +_config_path = _hermes_home / "config.yaml" if _config_path.exists(): try: import yaml as _yaml + with open(_config_path) as _f: _cfg = _yaml.safe_load(_f) or {} # Top-level simple values (fallback only — don't override .env) @@ -101,8 +103,8 @@ if _config_path.exists(): _auxiliary_cfg = _cfg.get("auxiliary", {}) if _auxiliary_cfg and isinstance(_auxiliary_cfg, dict): _aux_task_env = { - "vision": ("AUXILIARY_VISION_PROVIDER", "AUXILIARY_VISION_MODEL"), - "web_extract": ("AUXILIARY_WEB_EXTRACT_PROVIDER", "AUXILIARY_WEB_EXTRACT_MODEL"), + "vision": ("AUXILIARY_VISION_PROVIDER", "AUXILIARY_VISION_MODEL"), + "web_extract": ("AUXILIARY_WEB_EXTRACT_PROVIDER", "AUXILIARY_WEB_EXTRACT_MODEL"), } for _task_key, (_prov_env, _model_env) in _aux_task_env.items(): _task_cfg = _auxiliary_cfg.get(_task_key, {}) @@ -147,20 +149,20 @@ if not _configured_cwd or _configured_cwd in (".", "auto", "cwd"): os.environ["TERMINAL_CWD"] = messaging_cwd from gateway.config import ( - Platform, GatewayConfig, + Platform, load_gateway_config, ) +from gateway.delivery import DeliveryRouter +from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType from gateway.session import ( - SessionStore, - SessionSource, SessionContext, + SessionSource, + SessionStore, build_session_context, build_session_context_prompt, build_session_key, ) -from gateway.delivery import DeliveryRouter, DeliveryTarget -from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType logger = logging.getLogger(__name__) @@ -168,8 +170,8 @@ logger = logging.getLogger(__name__) def _resolve_runtime_agent_kwargs() -> dict: """Resolve provider credentials for gateway-created AIAgent instances.""" from hermes_cli.runtime_provider import ( - resolve_runtime_provider, format_runtime_provider_error, + resolve_runtime_provider, ) try: @@ -190,14 +192,14 @@ def _resolve_runtime_agent_kwargs() -> dict: class GatewayRunner: """ Main gateway controller. - + Manages the lifecycle of all platform adapters and routes messages to/from the agent. """ - - def __init__(self, config: Optional[GatewayConfig] = None): + + def __init__(self, config: GatewayConfig | None = None): self.config = config or load_gateway_config() - self.adapters: Dict[Platform, BasePlatformAdapter] = {} + self.adapters: dict[Platform, BasePlatformAdapter] = {} # Load ephemeral config from config.yaml / env vars. # Both are injected at API-call time only and never persisted. @@ -209,39 +211,44 @@ class GatewayRunner: # Wire process registry into session store for reset protection from tools.process_registry import process_registry + self.session_store = SessionStore( - self.config.sessions_dir, self.config, + self.config.sessions_dir, + self.config, has_active_processes_fn=lambda key: process_registry.has_active_for_session(key), ) self.delivery_router = DeliveryRouter(self.config) self._running = False self._shutdown_event = asyncio.Event() - + # Track running agents per session for interrupt support # Key: session_key, Value: AIAgent instance - self._running_agents: Dict[str, Any] = {} - self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt - + self._running_agents: dict[str, Any] = {} + self._pending_messages: dict[str, str] = {} # Queued messages during interrupt + # Track pending exec approvals per session # Key: session_key, Value: {"command": str, "pattern_key": str} - self._pending_approvals: Dict[str, Dict[str, str]] = {} - + self._pending_approvals: dict[str, dict[str, str]] = {} + # Initialize session database for session_search tool support self._session_db = None try: from hermes_state import SessionDB + self._session_db = SessionDB() except Exception as e: logger.debug("SQLite session store not available: %s", e) - + # DM pairing store for code-based user authorization from gateway.pairing import PairingStore + self.pairing_store = PairingStore() - + # Event hook system from gateway.hooks import HookRegistry + self.hooks = HookRegistry() - + def _flush_memories_for_session(self, old_session_id: str): """Prompt the agent to save memories/skills before context is lost. @@ -254,6 +261,7 @@ class GatewayRunner: return from run_agent import AIAgent + runtime_kwargs = _resolve_runtime_agent_kwargs() if not runtime_kwargs.get("api_key"): return @@ -300,20 +308,22 @@ class GatewayRunner: """Run the sync memory flush in a thread pool so it won't block the event loop.""" loop = asyncio.get_event_loop() await loop.run_in_executor(None, self._flush_memories_for_session, old_session_id) - + @staticmethod - def _load_prefill_messages() -> List[Dict[str, Any]]: + def _load_prefill_messages() -> list[dict[str, Any]]: """Load ephemeral prefill messages from config or env var. - + Checks HERMES_PREFILL_MESSAGES_FILE env var first, then falls back to the prefill_messages_file key in ~/.hermes/config.yaml. Relative paths are resolved from ~/.hermes/. """ import json as _json + file_path = os.getenv("HERMES_PREFILL_MESSAGES_FILE", "") if not file_path: try: import yaml as _y + cfg_path = _hermes_home / "config.yaml" if cfg_path.exists(): with open(cfg_path) as _f: @@ -330,7 +340,7 @@ class GatewayRunner: logger.warning("Prefill messages file not found: %s", path) return [] try: - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: data = _json.load(f) if not isinstance(data, list): logger.warning("Prefill messages file must contain a JSON array: %s", path) @@ -343,7 +353,7 @@ class GatewayRunner: @staticmethod def _load_ephemeral_system_prompt() -> str: """Load ephemeral system prompt from config or env var. - + Checks HERMES_EPHEMERAL_SYSTEM_PROMPT env var first, then falls back to agent.system_prompt in ~/.hermes/config.yaml. """ @@ -352,6 +362,7 @@ class GatewayRunner: return prompt try: import yaml as _y + cfg_path = _hermes_home / "config.yaml" if cfg_path.exists(): with open(cfg_path) as _f: @@ -364,7 +375,7 @@ class GatewayRunner: @staticmethod def _load_reasoning_config() -> dict | None: """Load reasoning effort from config or env var. - + Checks HERMES_REASONING_EFFORT env var first, then agent.reasoning_effort in config.yaml. Valid: "xhigh", "high", "medium", "low", "minimal", "none". Returns None to use default (medium). @@ -373,6 +384,7 @@ class GatewayRunner: if not effort: try: import yaml as _y + cfg_path = _hermes_home / "config.yaml" if cfg_path.exists(): with open(cfg_path) as _f: @@ -396,6 +408,7 @@ class GatewayRunner: """Load OpenRouter provider routing preferences from config.yaml.""" try: import yaml as _y + cfg_path = _hermes_home / "config.yaml" if cfg_path.exists(): with open(cfg_path) as _f: @@ -414,6 +427,7 @@ class GatewayRunner: """ try: import yaml as _y + cfg_path = _hermes_home / "config.yaml" if cfg_path.exists(): with open(cfg_path) as _f: @@ -428,18 +442,22 @@ class GatewayRunner: async def start(self) -> bool: """ Start the gateway and all configured platform adapters. - + Returns True if at least one adapter connected successfully. """ logger.info("Starting Hermes Gateway...") logger.info("Session storage: %s", self.config.sessions_dir) - + # Warn if no user allowlists are configured and open access is not opted in _any_allowlist = any( os.getenv(v) - for v in ("TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS", - "WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS", - "GATEWAY_ALLOWED_USERS") + for v in ( + "TELEGRAM_ALLOWED_USERS", + "DISCORD_ALLOWED_USERS", + "WHATSAPP_ALLOWED_USERS", + "SLACK_ALLOWED_USERS", + "GATEWAY_ALLOWED_USERS", + ) ) _allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") if not _any_allowlist and not _allow_all: @@ -448,34 +466,35 @@ class GatewayRunner: "Set GATEWAY_ALLOW_ALL_USERS=true in ~/.hermes/.env to allow open access, " "or configure platform allowlists (e.g., TELEGRAM_ALLOWED_USERS=your_id)." ) - + # Discover and load event hooks self.hooks.discover_and_load() - + # Recover background processes from checkpoint (crash recovery) try: from tools.process_registry import process_registry + recovered = process_registry.recover_from_checkpoint() if recovered: logger.info("Recovered %s background process(es) from previous run", recovered) except Exception as e: logger.warning("Process checkpoint recovery: %s", e) - + connected_count = 0 - + # Initialize and connect each configured platform for platform, platform_config in self.config.platforms.items(): if not platform_config.enabled: continue - + adapter = self._create_adapter(platform, platform_config) if not adapter: logger.warning("No adapter available for %s", platform.value) continue - + # Set up message handler adapter.set_message_handler(self._handle_message) - + # Try to connect logger.info("Connecting to %s...", platform.value) try: @@ -488,36 +507,40 @@ class GatewayRunner: logger.warning("✗ %s failed to connect", platform.value) except Exception as e: logger.error("✗ %s error: %s", platform.value, e) - + if connected_count == 0: logger.warning("No messaging platforms connected.") logger.info("Gateway will continue running for cron job execution.") - + # Update delivery router with adapters self.delivery_router.adapters = self.adapters - + self._running = True - + # Emit gateway:startup hook hook_count = len(self.hooks.loaded_hooks) if hook_count: logger.info("%s hook(s) loaded", hook_count) - await self.hooks.emit("gateway:startup", { - "platforms": [p.value for p in self.adapters.keys()], - }) - + await self.hooks.emit( + "gateway:startup", + { + "platforms": [p.value for p in self.adapters.keys()], + }, + ) + if connected_count > 0: logger.info("Gateway running with %s platform(s)", connected_count) - + # Build initial channel directory for send_message name resolution try: from gateway.channel_directory import build_channel_directory + directory = build_channel_directory(self.adapters) ch_count = sum(len(chs) for chs in directory.get("platforms", {}).values()) logger.info("Channel directory built: %d target(s)", ch_count) except Exception as e: logger.warning("Channel directory build failed: %s", e) - + # Check if we're restarting after a /update command await self._send_update_notification() @@ -525,12 +548,12 @@ class GatewayRunner: asyncio.create_task(self._session_expiry_watcher()) logger.info("Press Ctrl+C to stop") - + return True - + async def _session_expiry_watcher(self, interval: int = 300): """Background task that proactively flushes memories for expired sessions. - + Runs every `interval` seconds (default 5 min). For each session that has expired according to its reset policy, flushes memories in a thread pool and marks the session so it won't be flushed again. @@ -550,7 +573,8 @@ class GatewayRunner: # Session has expired — flush memories in the background logger.info( "Session %s expired (key=%s), flushing memories proactively", - entry.session_id, key, + entry.session_id, + key, ) try: await self._async_flush_memories(entry.session_id) @@ -569,55 +593,56 @@ class GatewayRunner: """Stop the gateway and disconnect all adapters.""" logger.info("Stopping gateway...") self._running = False - + for platform, adapter in self.adapters.items(): try: await adapter.disconnect() logger.info("✓ %s disconnected", platform.value) except Exception as e: logger.error("✗ %s disconnect error: %s", platform.value, e) - + self.adapters.clear() self._shutdown_event.set() - + from gateway.status import remove_pid_file + remove_pid_file() - + logger.info("Gateway stopped") - + async def wait_for_shutdown(self) -> None: """Wait for shutdown signal.""" await self._shutdown_event.wait() - - def _create_adapter( - self, - platform: Platform, - config: Any - ) -> Optional[BasePlatformAdapter]: + + def _create_adapter(self, platform: Platform, config: Any) -> BasePlatformAdapter | None: """Create the appropriate adapter for a platform.""" if platform == Platform.TELEGRAM: from gateway.platforms.telegram import TelegramAdapter, check_telegram_requirements + if not check_telegram_requirements(): logger.warning("Telegram: python-telegram-bot not installed") return None return TelegramAdapter(config) - + elif platform == Platform.DISCORD: from gateway.platforms.discord import DiscordAdapter, check_discord_requirements + if not check_discord_requirements(): logger.warning("Discord: discord.py not installed") return None return DiscordAdapter(config) - + elif platform == Platform.WHATSAPP: from gateway.platforms.whatsapp import WhatsAppAdapter, check_whatsapp_requirements + if not check_whatsapp_requirements(): logger.warning("WhatsApp: Node.js not installed or bridge not configured") return None return WhatsAppAdapter(config) - + elif platform == Platform.SLACK: from gateway.platforms.slack import SlackAdapter, check_slack_requirements + if not check_slack_requirements(): logger.warning("Slack: slack-bolt not installed. Run: pip install 'hermes-agent[slack]'") return None @@ -625,6 +650,7 @@ class GatewayRunner: elif platform == Platform.SIGNAL: from gateway.platforms.signal import SignalAdapter, check_signal_requirements + if not check_signal_requirements(): logger.warning("Signal: SIGNAL_HTTP_URL or SIGNAL_ACCOUNT not configured") return None @@ -632,17 +658,18 @@ class GatewayRunner: elif platform == Platform.HOMEASSISTANT: from gateway.platforms.homeassistant import HomeAssistantAdapter, check_ha_requirements + if not check_ha_requirements(): logger.warning("HomeAssistant: aiohttp not installed or HASS_TOKEN not set") return None return HomeAssistantAdapter(config) return None - + def _is_user_authorized(self, source: SessionSource) -> bool: """ Check if a user is authorized to use the bot. - + Checks in order: 1. Per-platform allow-all flag (e.g., DISCORD_ALLOW_ALL_USERS=true) 2. Environment variable allowlists (TELEGRAM_ALLOWED_USERS, etc.) @@ -705,11 +732,11 @@ class GatewayRunner: if "@" in user_id: check_ids.add(user_id.split("@")[0]) return bool(check_ids & allowed_ids) - - async def _handle_message(self, event: MessageEvent) -> Optional[str]: + + async def _handle_message(self, event: MessageEvent) -> str | None: """ Handle an incoming message from any platform. - + This is the core message processing pipeline: 1. Check user authorization 2. Check for commands (/new, /reset, etc.) @@ -720,16 +747,14 @@ class GatewayRunner: 7. Return response """ source = event.source - + # Check if user is authorized if not self._is_user_authorized(source): logger.warning("Unauthorized user: %s (%s) on %s", source.user_id, source.user_name, source.platform.value) # In DMs: offer pairing code. In groups: silently ignore. if source.chat_type == "dm": platform_name = source.platform.value if source.platform else "unknown" - code = self.pairing_store.generate_code( - platform_name, source.user_id, source.user_name or "" - ) + code = self.pairing_store.generate_code(platform_name, source.user_id, source.user_name or "") if code: adapter = self.adapters.get(source.platform) if adapter: @@ -738,18 +763,16 @@ class GatewayRunner: f"Hi~ I don't recognize you yet!\n\n" f"Here's your pairing code: `{code}`\n\n" f"Ask the bot owner to run:\n" - f"`hermes pairing approve {platform_name} {code}`" + f"`hermes pairing approve {platform_name} {code}`", ) else: adapter = self.adapters.get(source.platform) if adapter: await adapter.send( - source.chat_id, - "Too many pairing requests right now~ " - "Please try again later!" + source.chat_id, "Too many pairing requests right now~ Please try again later!" ) return None - + # PRIORITY: If an agent is already running for this session, interrupt it # immediately. This is before command parsing to minimize latency -- the # user's "stop" message reaches the agent as fast as possible. @@ -763,50 +786,71 @@ class GatewayRunner: else: self._pending_messages[_quick_key] = event.text return None - + # Check for commands command = event.get_command() - + # Emit command:* hook for any recognized slash command - _known_commands = {"new", "reset", "help", "status", "stop", "model", - "personality", "retry", "undo", "sethome", "set-home", - "compress", "usage", "insights", "reload-mcp", "reload_mcp", - "update", "title", "resume", "provider"} + _known_commands = { + "new", + "reset", + "help", + "status", + "stop", + "model", + "personality", + "retry", + "undo", + "sethome", + "set-home", + "compress", + "usage", + "insights", + "reload-mcp", + "reload_mcp", + "update", + "title", + "resume", + "provider", + } if command and command in _known_commands: - await self.hooks.emit(f"command:{command}", { - "platform": source.platform.value if source.platform else "", - "user_id": source.user_id, - "command": command, - "args": event.get_command_args().strip(), - }) - + await self.hooks.emit( + f"command:{command}", + { + "platform": source.platform.value if source.platform else "", + "user_id": source.user_id, + "command": command, + "args": event.get_command_args().strip(), + }, + ) + if command in ["new", "reset"]: return await self._handle_reset_command(event) - + if command == "help": return await self._handle_help_command(event) - + if command == "status": return await self._handle_status_command(event) - + if command == "stop": return await self._handle_stop_command(event) - + if command == "model": return await self._handle_model_command(event) - + if command == "provider": return await self._handle_provider_command(event) - + if command == "personality": return await self._handle_personality_command(event) - + if command == "retry": return await self._handle_retry_command(event) - + if command == "undo": return await self._handle_undo_command(event) - + if command in ["sethome", "set-home"]: return await self._handle_set_home_command(event) @@ -830,11 +874,12 @@ class GatewayRunner: if command == "resume": return await self._handle_resume_command(event) - + # Skill slash commands: /skill-name loads the skill and sends to agent if command: try: - from agent.skill_commands import get_skill_commands, build_skill_invocation_message + from agent.skill_commands import build_skill_invocation_message, get_skill_commands + skill_cmds = get_skill_commands() cmd_key = f"/{command}" if cmd_key in skill_cmds: @@ -845,7 +890,7 @@ class GatewayRunner: # Fall through to normal message processing with skill content except Exception as e: logger.debug("Skill command check failed (non-fatal): %s", e) - + # Check for pending exec approval responses session_key_preview = build_session_key(source) if session_key_preview in self._pending_approvals: @@ -855,8 +900,9 @@ class GatewayRunner: cmd = approval["command"] pattern_key = approval.get("pattern_key", "") logger.info("User approved dangerous command: %s...", cmd[:60]) - from tools.terminal_tool import terminal_tool from tools.approval import approve_session + from tools.terminal_tool import terminal_tool + approve_session(session_key_preview, pattern_key) result = terminal_tool(command=cmd, force=True) return f"✅ Command approved and executed.\n\n```\n{result[:3500]}\n```" @@ -864,46 +910,47 @@ class GatewayRunner: self._pending_approvals.pop(session_key_preview) return "❌ Command denied." # If it's not clearly an approval/denial, fall through to normal processing - + # Get or create session session_entry = self.session_store.get_or_create_session(source) session_key = session_entry.session_key - + # Emit session:start for new or auto-reset sessions - _is_new_session = ( - session_entry.created_at == session_entry.updated_at - or getattr(session_entry, "was_auto_reset", False) + _is_new_session = session_entry.created_at == session_entry.updated_at or getattr( + session_entry, "was_auto_reset", False ) if _is_new_session: - await self.hooks.emit("session:start", { - "platform": source.platform.value if source.platform else "", - "user_id": source.user_id, - "session_id": session_entry.session_id, - "session_key": session_key, - }) - + await self.hooks.emit( + "session:start", + { + "platform": source.platform.value if source.platform else "", + "user_id": source.user_id, + "session_id": session_entry.session_id, + "session_key": session_key, + }, + ) + # Build session context context = build_session_context(source, self.config, session_entry) - + # Set environment variables for tools self._set_session_env(context) - + # Build the context prompt to inject context_prompt = build_session_context_prompt(context) - + # If the previous session expired and was auto-reset, prepend a notice # so the agent knows this is a fresh conversation (not an intentional /reset). - if getattr(session_entry, 'was_auto_reset', False): + if getattr(session_entry, "was_auto_reset", False): context_prompt = ( "[System note: The user's previous session expired due to inactivity. " - "This is a fresh conversation with no prior context.]\n\n" - + context_prompt + "This is a fresh conversation with no prior context.]\n\n" + context_prompt ) session_entry.was_auto_reset = False - + # Load conversation history from transcript history = self.session_store.load_transcript(session_entry.session_id) - + # ----------------------------------------------------------------- # Session hygiene: auto-compress pathologically large transcripts # @@ -931,6 +978,7 @@ class GatewayRunner: _hyg_cfg_path = _hermes_home / "config.yaml" if _hyg_cfg_path.exists(): import yaml as _hyg_yaml + with open(_hyg_cfg_path) as _hyg_f: _hyg_data = _hyg_yaml.safe_load(_hyg_f) or {} @@ -944,27 +992,19 @@ class GatewayRunner: # Read compression settings _comp_cfg = _hyg_data.get("compression", {}) if isinstance(_comp_cfg, dict): - _hyg_threshold_pct = float( - _comp_cfg.get("threshold", _hyg_threshold_pct) - ) - _hyg_compression_enabled = str( - _comp_cfg.get("enabled", True) - ).lower() in ("true", "1", "yes") + _hyg_threshold_pct = float(_comp_cfg.get("threshold", _hyg_threshold_pct)) + _hyg_compression_enabled = str(_comp_cfg.get("enabled", True)).lower() in ("true", "1", "yes") except Exception: pass # Also check env overrides (same as run_agent.py) - _hyg_threshold_pct = float( - os.getenv("CONTEXT_COMPRESSION_THRESHOLD", str(_hyg_threshold_pct)) - ) + _hyg_threshold_pct = float(os.getenv("CONTEXT_COMPRESSION_THRESHOLD", str(_hyg_threshold_pct))) if os.getenv("CONTEXT_COMPRESSION_ENABLED", "").lower() in ("false", "0", "no"): _hyg_compression_enabled = False if _hyg_compression_enabled: _hyg_context_length = get_model_context_length(_hyg_model) - _compress_token_threshold = int( - _hyg_context_length * _hyg_threshold_pct - ) + _compress_token_threshold = int(_hyg_context_length * _hyg_threshold_pct) # Warn if still huge after compression (95% of context) _warn_token_threshold = int(_hyg_context_length * 0.95) @@ -977,7 +1017,8 @@ class GatewayRunner: logger.info( "Session hygiene: %s messages, ~%s tokens — auto-compressing " "(threshold: %s%% of %s = %s tokens)", - _msg_count, f"{_approx_tokens:,}", + _msg_count, + f"{_approx_tokens:,}", int(_hyg_threshold_pct * 100), f"{_hyg_context_length:,}", f"{_compress_token_threshold:,}", @@ -989,7 +1030,7 @@ class GatewayRunner: await _hyg_adapter.send( source.chat_id, f"🗜️ Session is large ({_msg_count} messages, " - f"~{_approx_tokens:,} tokens). Auto-compressing..." + f"~{_approx_tokens:,} tokens). Auto-compressing...", ) except Exception: pass @@ -1002,8 +1043,7 @@ class GatewayRunner: _hyg_msgs = [ {"role": m.get("role"), "content": m.get("content")} for m in history - if m.get("role") in ("user", "assistant") - and m.get("content") + if m.get("role") in ("user", "assistant") and m.get("content") ] if len(_hyg_msgs) >= 4: @@ -1019,25 +1059,23 @@ class GatewayRunner: _compressed, _ = await loop.run_in_executor( None, lambda: _hyg_agent._compress_context( - _hyg_msgs, "", + _hyg_msgs, + "", approx_tokens=_approx_tokens, ), ) - self.session_store.rewrite_transcript( - session_entry.session_id, _compressed - ) + self.session_store.rewrite_transcript(session_entry.session_id, _compressed) history = _compressed _new_count = len(_compressed) - _new_tokens = estimate_messages_tokens_rough( - _compressed - ) + _new_tokens = estimate_messages_tokens_rough(_compressed) logger.info( - "Session hygiene: compressed %s → %s msgs, " - "~%s → ~%s tokens", - _msg_count, _new_count, - f"{_approx_tokens:,}", f"{_new_tokens:,}", + "Session hygiene: compressed %s → %s msgs, ~%s → ~%s tokens", + _msg_count, + _new_count, + f"{_approx_tokens:,}", + f"{_new_tokens:,}", ) if _hyg_adapter: @@ -1047,7 +1085,7 @@ class GatewayRunner: f"🗜️ Compressed: {_msg_count} → " f"{_new_count} messages, " f"~{_approx_tokens:,} → " - f"~{_new_tokens:,} tokens" + f"~{_new_tokens:,} tokens", ) except Exception: pass @@ -1055,8 +1093,7 @@ class GatewayRunner: # Still too large after compression — warn user if _new_tokens >= _warn_token_threshold: logger.warning( - "Session hygiene: still ~%s tokens after " - "compression — suggesting /reset", + "Session hygiene: still ~%s tokens after compression — suggesting /reset", f"{_new_tokens:,}", ) if _hyg_adapter: @@ -1067,15 +1104,13 @@ class GatewayRunner: "after compression " f"(~{_new_tokens:,} tokens). " "Consider using /reset to start " - "fresh if you experience issues." + "fresh if you experience issues.", ) except Exception: pass except Exception as e: - logger.warning( - "Session hygiene auto-compress failed: %s", e - ) + logger.warning("Session hygiene auto-compress failed: %s", e) # Compression failed and session is dangerously large if _approx_tokens >= _warn_token_threshold: _hyg_adapter = self.adapters.get(source.platform) @@ -1088,7 +1123,7 @@ class GatewayRunner: f"~{_approx_tokens:,} tokens) and " "auto-compression failed. Consider " "using /compress or /reset to avoid " - "issues." + "issues.", ) except Exception: pass @@ -1100,7 +1135,7 @@ class GatewayRunner: "Briefly introduce yourself and mention that /help shows available commands. " "Keep the introduction concise -- one or two sentences max.]" ) - + # One-time prompt if no home channel is set for this platform if not history and source.platform and source.platform != Platform.LOCAL: platform_name = source.platform.value @@ -1114,9 +1149,9 @@ class GatewayRunner: f"A home channel is where Hermes delivers cron job results " f"and cross-platform messages.\n\n" f"Type /sethome to make this chat your home channel, " - f"or ignore to skip." + f"or ignore to skip.", ) - + # ----------------------------------------------------------------- # Auto-analyze images sent by the user # @@ -1135,17 +1170,12 @@ class GatewayRunner: for i, path in enumerate(event.media_urls): # Check media_types if available; otherwise infer from message type mtype = event.media_types[i] if i < len(event.media_types) else "" - is_image = ( - mtype.startswith("image/") - or event.message_type == MessageType.PHOTO - ) + is_image = mtype.startswith("image/") or event.message_type == MessageType.PHOTO if is_image: image_paths.append(path) if image_paths: - message_text = await self._enrich_message_with_vision( - message_text, image_paths - ) - + message_text = await self._enrich_message_with_vision(message_text, image_paths) + # ----------------------------------------------------------------- # Auto-transcribe voice/audio messages sent by the user # ----------------------------------------------------------------- @@ -1153,16 +1183,11 @@ class GatewayRunner: audio_paths = [] for i, path in enumerate(event.media_urls): mtype = event.media_types[i] if i < len(event.media_types) else "" - is_audio = ( - mtype.startswith("audio/") - or event.message_type in (MessageType.VOICE, MessageType.AUDIO) - ) + is_audio = mtype.startswith("audio/") or event.message_type in (MessageType.VOICE, MessageType.AUDIO) if is_audio: audio_paths.append(path) if audio_paths: - message_text = await self._enrich_message_with_transcription( - message_text, audio_paths - ) + message_text = await self._enrich_message_with_transcription(message_text, audio_paths) # ----------------------------------------------------------------- # Enrich document messages with context notes for the agent @@ -1174,13 +1199,15 @@ class GatewayRunner: continue # Extract display filename by stripping the doc_{uuid12}_ prefix import os as _os + basename = _os.path.basename(path) # Format: doc_<12hex>_ parts = basename.split("_", 2) display_name = parts[2] if len(parts) >= 3 else basename # Sanitize to prevent prompt injection via filenames import re as _re - display_name = _re.sub(r'[^\w.\- ]', '_', display_name) + + display_name = _re.sub(r"[^\w.\- ]", "_", display_name) if mtype.startswith("text/"): context_note = ( @@ -1205,7 +1232,7 @@ class GatewayRunner: "message": message_text[:500], } await self.hooks.emit("agent:start", hook_ctx) - + # Run the agent agent_result = await self._run_agent( message=message_text, @@ -1213,21 +1240,25 @@ class GatewayRunner: history=history, source=source, session_id=session_entry.session_id, - session_key=session_key + session_key=session_key, ) - + response = agent_result.get("final_response", "") agent_messages = agent_result.get("messages", []) - + # Emit agent:end hook - await self.hooks.emit("agent:end", { - **hook_ctx, - "response": (response or "")[:500], - }) - + await self.hooks.emit( + "agent:end", + { + **hook_ctx, + "response": (response or "")[:500], + }, + ) + # Check for pending process watchers (check_interval on background processes) try: from tools.process_registry import process_registry + while process_registry.pending_watchers: watcher = process_registry.pending_watchers.pop(0) asyncio.create_task(self._run_process_watcher(watcher)) @@ -1237,18 +1268,19 @@ class GatewayRunner: # Check if the agent encountered a dangerous command needing approval try: from tools.approval import pop_pending + pending = pop_pending(session_key) if pending: self._pending_approvals[session_key] = pending except Exception as e: logger.debug("Failed to check pending approvals: %s", e) - + # Save the full conversation to the transcript, including tool calls. # This preserves the complete agent loop (tool_calls, tool results, # intermediate reasoning) so sessions can be resumed with full context # and transcripts are useful for debugging and training data. ts = datetime.now().isoformat() - + # If this is a fresh session (no history), write the full tool # definitions as the first entry so the transcript is self-describing # -- the same list of dicts sent as tools=[...] in the API request. @@ -1262,26 +1294,24 @@ class GatewayRunner: "model": os.getenv("HERMES_MODEL", ""), "platform": source.platform.value if source.platform else "", "timestamp": ts, - } + }, ) - + # Find only the NEW messages from this turn (skip history we loaded). # Use the filtered history length (history_offset) that was actually # passed to the agent, not len(history) which includes session_meta # entries that were stripped before the agent saw them. history_len = agent_result.get("history_offset", len(history)) new_messages = agent_messages[history_len:] if len(agent_messages) > history_len else [] - + # If no new messages found (edge case), fall back to simple user/assistant if not new_messages: self.session_store.append_to_transcript( - session_entry.session_id, - {"role": "user", "content": message_text, "timestamp": ts} + session_entry.session_id, {"role": "user", "content": message_text, "timestamp": ts} ) if response: self.session_store.append_to_transcript( - session_entry.session_id, - {"role": "assistant", "content": response, "timestamp": ts} + session_entry.session_id, {"role": "assistant", "content": response, "timestamp": ts} ) else: for msg in new_messages: @@ -1290,16 +1320,14 @@ class GatewayRunner: continue # Add timestamp to each message for debugging entry = {**msg, "timestamp": ts} - self.session_store.append_to_transcript( - session_entry.session_id, entry - ) - + self.session_store.append_to_transcript(session_entry.session_id, entry) + # Update session self.session_store.update_session(session_entry.session_key) - + return response - - except Exception as e: + + except Exception: logger.exception("Agent error in session %s", session_key) return ( "Sorry, I encountered an unexpected error. " @@ -1309,14 +1337,14 @@ class GatewayRunner: finally: # Clear session env self._clear_session_env() - + async def _handle_reset_command(self, event: MessageEvent) -> str: """Handle /new or /reset command.""" source = event.source - + # Get existing session key session_key = self.session_store._generate_session_key(source) - + # Flush memories in the background (fire-and-forget) so the user # gets the "Session reset!" response immediately. try: @@ -1325,35 +1353,38 @@ class GatewayRunner: asyncio.create_task(self._async_flush_memories(old_entry.session_id)) except Exception as e: logger.debug("Gateway memory flush on reset failed: %s", e) - + # Reset the session new_entry = self.session_store.reset_session(session_key) - + # Emit session:reset hook - await self.hooks.emit("session:reset", { - "platform": source.platform.value if source.platform else "", - "user_id": source.user_id, - "session_key": session_key, - }) - + await self.hooks.emit( + "session:reset", + { + "platform": source.platform.value if source.platform else "", + "user_id": source.user_id, + "session_key": session_key, + }, + ) + if new_entry: return "✨ Session reset! I've started fresh with no memory of our previous conversation." else: # No existing session, just create one self.session_store.get_or_create_session(source, force_new=True) return "✨ New session started!" - + async def _handle_status_command(self, event: MessageEvent) -> str: """Handle /status command.""" source = event.source session_entry = self.session_store.get_or_create_session(source) - + connected_platforms = [p.value for p in self.adapters.keys()] - + # Check if there's an active agent session_key = session_entry.session_key is_running = session_key in self._running_agents - + lines = [ "📊 **Hermes Gateway Status**", "", @@ -1365,22 +1396,22 @@ class GatewayRunner: "", f"**Connected Platforms:** {', '.join(connected_platforms)}", ] - + return "\n".join(lines) - + async def _handle_stop_command(self, event: MessageEvent) -> str: """Handle /stop command - interrupt a running agent.""" source = event.source session_entry = self.session_store.get_or_create_session(source) session_key = session_entry.session_key - + if session_key in self._running_agents: agent = self._running_agents[session_key] agent.interrupt() return "⚡ Stopping the current task... The agent will finish its current step and respond." else: return "No active task to stop." - + async def _handle_help_command(self, event: MessageEvent) -> str: """Handle /help command - list available commands.""" lines = [ @@ -1406,6 +1437,7 @@ class GatewayRunner: ] try: from agent.skill_commands import get_skill_commands + skill_cmds = get_skill_commands() if skill_cmds: lines.append(f"\n⚡ **Skill Commands** ({len(skill_cmds)} installed):") @@ -1414,20 +1446,21 @@ class GatewayRunner: except Exception: pass return "\n".join(lines) - + async def _handle_model_command(self, event: MessageEvent) -> str: """Handle /model command - show or change the current model.""" import yaml + from hermes_cli.models import ( - parse_model_input, - validate_requested_model, + _PROVIDER_LABELS, curated_models_for_provider, normalize_provider, - _PROVIDER_LABELS, + parse_model_input, + validate_requested_model, ) args = event.get_command_args().strip() - config_path = _hermes_home / 'config.yaml' + config_path = _hermes_home / "config.yaml" # Resolve current model and provider from config current = os.getenv("HERMES_MODEL") or os.getenv("LLM_MODEL") or "anthropic/claude-opus-4.6" @@ -1450,6 +1483,7 @@ class GatewayRunner: if current_provider == "auto": try: from hermes_cli.auth import resolve_provider as _resolve_provider + current_provider = _resolve_provider(current_provider) except Exception: current_provider = "openrouter" @@ -1488,6 +1522,7 @@ class GatewayRunner: if provider_changed: try: from hermes_cli.runtime_provider import resolve_runtime_provider + runtime = resolve_runtime_provider(requested=target_provider) api_key = runtime.get("api_key", "") base_url = runtime.get("base_url", "") @@ -1498,6 +1533,7 @@ class GatewayRunner: # Use current provider's base_url from config or registry try: from hermes_cli.runtime_provider import resolve_runtime_provider + runtime = resolve_runtime_provider(requested=current_provider) api_key = runtime.get("api_key", "") base_url = runtime.get("base_url", "") @@ -1517,7 +1553,11 @@ class GatewayRunner: if not validation.get("accepted"): msg = validation.get("message", "Invalid model") - tip = "\n\nUse `/model` to see available models, `/provider` to see providers" if "Did you mean" not in msg else "" + tip = ( + "\n\nUse `/model` to see available models, `/provider` to see providers" + if "Did you mean" not in msg + else "" + ) return f"⚠️ {msg}{tip}" # Persist to config only if validation approves @@ -1532,7 +1572,7 @@ class GatewayRunner: user_config["model"]["default"] = new_model if provider_changed: user_config["model"]["provider"] = target_provider - with open(config_path, 'w') as f: + with open(config_path, "w") as f: yaml.dump(user_config, f, default_flow_style=False, sort_keys=False) except Exception as e: return f"⚠️ Failed to save model change: {e}" @@ -1558,15 +1598,16 @@ class GatewayRunner: async def _handle_provider_command(self, event: MessageEvent) -> str: """Handle /provider command - show available providers.""" import yaml + from hermes_cli.models import ( + _PROVIDER_LABELS, list_available_providers, normalize_provider, - _PROVIDER_LABELS, ) # Resolve current provider from config current_provider = "openrouter" - config_path = _hermes_home / 'config.yaml' + config_path = _hermes_home / "config.yaml" try: if config_path.exists(): with open(config_path) as f: @@ -1581,6 +1622,7 @@ class GatewayRunner: if current_provider == "auto": try: from hermes_cli.auth import resolve_provider as _resolve_provider + current_provider = _resolve_provider(current_provider) except Exception: current_provider = "openrouter" @@ -1608,17 +1650,17 @@ class GatewayRunner: lines.append("Switch: `/model provider:model-name`") lines.append("Setup: `hermes setup`") return "\n".join(lines) - + async def _handle_personality_command(self, event: MessageEvent) -> str: """Handle /personality command - list or set a personality.""" import yaml args = event.get_command_args().strip().lower() - config_path = _hermes_home / 'config.yaml' + config_path = _hermes_home / "config.yaml" try: if config_path.exists(): - with open(config_path, 'r') as f: + with open(config_path) as f: config = yaml.safe_load(f) or {} personalities = config.get("agent", {}).get("personalities", {}) else: @@ -1636,7 +1678,7 @@ class GatewayRunner: for name, prompt in personalities.items(): preview = prompt[:50] + "..." if len(prompt) > 50 else prompt lines.append(f"• `{name}` — {preview}") - lines.append(f"\nUsage: `/personality `") + lines.append("\nUsage: `/personality `") return "\n".join(lines) if args in personalities: @@ -1647,7 +1689,7 @@ class GatewayRunner: if "agent" not in config or not isinstance(config.get("agent"), dict): config["agent"] = {} config["agent"]["system_prompt"] = new_prompt - with open(config_path, 'w') as f: + with open(config_path, "w") as f: yaml.dump(config, f, default_flow_style=False, sort_keys=False) except Exception as e: return f"⚠️ Failed to save personality change: {e}" @@ -1659,13 +1701,13 @@ class GatewayRunner: available = ", ".join(f"`{n}`" for n in personalities.keys()) return f"Unknown personality: `{args}`\n\nAvailable: {available}" - + async def _handle_retry_command(self, event: MessageEvent) -> str: """Handle /retry command - re-send the last user message.""" source = event.source session_entry = self.session_store.get_or_create_session(source) history = self.session_store.load_transcript(session_entry.session_id) - + # Find the last user message last_user_msg = None last_user_idx = None @@ -1674,14 +1716,14 @@ class GatewayRunner: last_user_msg = history[i].get("content", "") last_user_idx = i break - + if not last_user_msg: return "No previous message to retry." - + # Truncate history to before the last user message and persist truncated = history[:last_user_idx] self.session_store.rewrite_transcript(session_entry.session_id, truncated) - + # Re-send by creating a fake text event with the old message retry_event = MessageEvent( text=last_user_msg, @@ -1689,63 +1731,64 @@ class GatewayRunner: source=source, raw_message=event.raw_message, ) - + # Let the normal message handler process it return await self._handle_message(retry_event) - + async def _handle_undo_command(self, event: MessageEvent) -> str: """Handle /undo command - remove the last user/assistant exchange.""" source = event.source session_entry = self.session_store.get_or_create_session(source) history = self.session_store.load_transcript(session_entry.session_id) - + # Find the last user message and remove everything from it onward last_user_idx = None for i in range(len(history) - 1, -1, -1): if history[i].get("role") == "user": last_user_idx = i break - + if last_user_idx is None: return "Nothing to undo." - + removed_msg = history[last_user_idx].get("content", "") removed_count = len(history) - last_user_idx self.session_store.rewrite_transcript(session_entry.session_id, history[:last_user_idx]) - + preview = removed_msg[:40] + "..." if len(removed_msg) > 40 else removed_msg - return f"↩️ Undid {removed_count} message(s).\nRemoved: \"{preview}\"" - + return f'↩️ Undid {removed_count} message(s).\nRemoved: "{preview}"' + async def _handle_set_home_command(self, event: MessageEvent) -> str: """Handle /sethome command -- set the current chat as the platform's home channel.""" source = event.source platform_name = source.platform.value if source.platform else "unknown" chat_id = source.chat_id chat_name = source.chat_name or chat_id - + env_key = f"{platform_name.upper()}_HOME_CHANNEL" - + # Save to config.yaml try: import yaml - config_path = _hermes_home / 'config.yaml' + + config_path = _hermes_home / "config.yaml" user_config = {} if config_path.exists(): with open(config_path) as f: user_config = yaml.safe_load(f) or {} user_config[env_key] = chat_id - with open(config_path, 'w') as f: + with open(config_path, "w") as f: yaml.dump(user_config, f, default_flow_style=False) # Also set in the current environment so it takes effect immediately os.environ[env_key] = str(chat_id) except Exception as e: return f"Failed to save home channel: {e}" - + return ( f"✅ Home channel set to **{chat_name}** (ID: {chat_id}).\n" f"Cron jobs and cross-platform messages will be delivered here." ) - + async def _handle_compress_command(self, event: MessageEvent) -> str: """Handle /compress command -- manually compress conversation context.""" source = event.source @@ -1756,8 +1799,8 @@ class GatewayRunner: return "Not enough conversation to compress (need at least 4 messages)." try: - from run_agent import AIAgent from agent.model_metadata import estimate_messages_tokens_rough + from run_agent import AIAgent runtime_kwargs = _resolve_runtime_agent_kwargs() if not runtime_kwargs.get("api_key"): @@ -1789,10 +1832,7 @@ class GatewayRunner: new_count = len(compressed) new_tokens = estimate_messages_tokens_rough(compressed) - return ( - f"🗜️ Compressed: {original_count} → {new_count} messages\n" - f"~{approx_tokens:,} → ~{new_tokens:,} tokens" - ) + return f"🗜️ Compressed: {original_count} → {new_count} messages\n~{approx_tokens:,} → ~{new_tokens:,} tokens" except Exception as e: logger.warning("Manual compress failed: %s", e) return f"Compression failed: {e}" @@ -1844,9 +1884,7 @@ class GatewayRunner: # List recent titled sessions for this user/platform try: user_source = source.platform.value if source.platform else None - sessions = self._session_db.list_sessions_rich( - source=user_source, limit=10 - ) + sessions = self._session_db.list_sessions_rich(source=user_source, limit=10) titled = [s for s in sessions if s.get("title")] if not titled: return ( @@ -1870,8 +1908,7 @@ class GatewayRunner: target_id = self._session_db.resolve_session_by_title(name) if not target_id: return ( - f"No session found matching '**{name}**'.\n" - "Use `/resume` with no arguments to see available sessions." + f"No session found matching '**{name}**'.\nUse `/resume` with no arguments to see available sessions." ) # Check if already on that session @@ -1931,6 +1968,7 @@ class GatewayRunner: history = self.session_store.load_transcript(session_entry.session_id) if history: from agent.model_metadata import estimate_messages_tokens_rough + msgs = [m for m in history if m.get("role") in ("user", "assistant") and m.get("content")] approx = estimate_messages_tokens_rough(msgs) return ( @@ -1970,8 +2008,8 @@ class GatewayRunner: i += 1 try: - from hermes_state import SessionDB from agent.insights import InsightsEngine + from hermes_state import SessionDB loop = _asyncio.get_event_loop() @@ -1992,7 +2030,7 @@ class GatewayRunner: """Handle /reload-mcp command -- disconnect and reconnect all MCP servers.""" loop = asyncio.get_event_loop() try: - from tools.mcp_tool import shutdown_mcp_servers, discover_mcp_tools, _load_mcp_config, _servers, _lock + from tools.mcp_tool import _load_mcp_config, _lock, _servers, discover_mcp_tools, shutdown_mcp_servers # Capture old server names before shutdown with _lock: @@ -2046,9 +2084,7 @@ class GatewayRunner: } try: session_entry = self.session_store.get_or_create_session(event.source) - self.session_store.append_to_transcript( - session_entry.session_id, reload_msg - ) + self.session_store.append_to_transcript(session_entry.session_id, reload_msg) except Exception: pass # Best-effort; don't fail the reload over a transcript write @@ -2072,7 +2108,7 @@ class GatewayRunner: from datetime import datetime project_root = Path(__file__).parent.parent.resolve() - git_dir = project_root / '.git' + git_dir = project_root / ".git" if not git_dir.exists(): return "✗ Not a git repository — cannot update." @@ -2099,9 +2135,7 @@ class GatewayRunner: systemd_run = shutil.which("systemd-run") if systemd_run: subprocess.Popen( - [systemd_run, "--user", "--scope", - "--unit=hermes-update", "--", - "bash", "-c", update_cmd], + [systemd_run, "--user", "--scope", "--unit=hermes-update", "--", "bash", "-c", update_cmd], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, start_new_session=True, @@ -2147,7 +2181,7 @@ class GatewayRunner: if adapter and chat_id: # Strip ANSI escape codes for clean display - output = _re.sub(r'\x1b\[[0-9;]*m', '', output).strip() + output = _re.sub(r"\x1b\[[0-9;]*m", "", output).strip() if output: # Truncate if too long for a single message if len(output) > 3500: @@ -2169,17 +2203,17 @@ class GatewayRunner: os.environ["HERMES_SESSION_CHAT_ID"] = context.source.chat_id if context.source.chat_name: os.environ["HERMES_SESSION_CHAT_NAME"] = context.source.chat_name - + def _clear_session_env(self) -> None: """Clear session environment variables.""" for var in ["HERMES_SESSION_PLATFORM", "HERMES_SESSION_CHAT_ID", "HERMES_SESSION_CHAT_NAME"]: if var in os.environ: del os.environ[var] - + async def _enrich_message_with_vision( self, user_text: str, - image_paths: List[str], + image_paths: list[str], ) -> str: """ Auto-analyze user-attached images with the vision tool and prepend @@ -2197,9 +2231,10 @@ class GatewayRunner: Returns: The enriched message string with vision descriptions prepended. """ - from tools.vision_tools import vision_analyze_tool import json as _json + from tools.vision_tools import vision_analyze_tool + analysis_prompt = ( "Describe everything visible in this image in thorough detail. " "Include any text, code, data, objects, people, layout, colors, " @@ -2247,7 +2282,7 @@ class GatewayRunner: async def _enrich_message_with_transcription( self, user_text: str, - audio_paths: List[str], + audio_paths: list[str], ) -> str: """ Auto-transcribe user voice/audio messages using OpenAI Whisper API @@ -2260,9 +2295,10 @@ class GatewayRunner: Returns: The enriched message string with transcriptions prepended. """ - from tools.transcription_tools import transcribe_audio import asyncio + from tools.transcription_tools import transcribe_audio + enriched_parts = [] for path in audio_paths: try: @@ -2270,10 +2306,7 @@ class GatewayRunner: result = await asyncio.to_thread(transcribe_audio, path) if result["success"]: transcript = result["transcript"] - enriched_parts.append( - f'[The user sent a voice message~ ' - f'Here\'s what they said: "{transcript}"]' - ) + enriched_parts.append(f'[The user sent a voice message~ Here\'s what they said: "{transcript}"]') else: error = result.get("error", "unknown error") if "OPENAI_API_KEY" in error or "VOICE_TOOLS_OPENAI_KEY" in error: @@ -2284,8 +2317,7 @@ class GatewayRunner: ) else: enriched_parts.append( - "[The user sent a voice message but I had trouble " - f"transcribing it~ ({error})]" + f"[The user sent a voice message but I had trouble transcribing it~ ({error})]" ) except Exception as e: logger.error("Transcription error: %s", e) @@ -2353,10 +2385,7 @@ class GatewayRunner: elif has_new_output: # New output available -- deliver status update new_output = session.output_buffer[-500:] if session.output_buffer else "" - message_text = ( - f"[Background process {session_id} is still running~ " - f"New output:\n{new_output}]" - ) + message_text = f"[Background process {session_id} is still running~ New output:\n{new_output}]" adapter = None for p, a in self.adapters.items(): if p.value == platform_name: @@ -2374,26 +2403,27 @@ class GatewayRunner: self, message: str, context_prompt: str, - history: List[Dict[str, Any]], + history: list[dict[str, Any]], source: SessionSource, session_id: str, - session_key: str = None - ) -> Dict[str, Any]: + session_key: str = None, + ) -> dict[str, Any]: """ Run the agent with the given message and context. - + Returns the full result dict from run_conversation, including: - "final_response": str (the text to send back) - "messages": list (full conversation including tool calls) - "api_calls": int - "completed": bool - + This is run in a thread pool to not block the event loop. Supports interruption via new messages. """ - from run_agent import AIAgent import queue - + + from run_agent import AIAgent + # Determine toolset based on platform. # Check config.yaml for per-platform overrides, fallback to hardcoded defaults. default_toolset_map = { @@ -2403,19 +2433,20 @@ class GatewayRunner: Platform.WHATSAPP: "hermes-whatsapp", Platform.SLACK: "hermes-slack", } - + # Try to load platform_toolsets from config platform_toolsets_config = {} try: - config_path = _hermes_home / 'config.yaml' + config_path = _hermes_home / "config.yaml" if config_path.exists(): import yaml - with open(config_path, 'r') as f: + + with open(config_path) as f: user_config = yaml.safe_load(f) or {} platform_toolsets_config = user_config.get("platform_toolsets", {}) except Exception as e: logger.debug("Could not load platform_toolsets config: %s", e) - + # Map platform enum to config key platform_config_key = { Platform.LOCAL: "cli", @@ -2424,7 +2455,7 @@ class GatewayRunner: Platform.WHATSAPP: "whatsapp", Platform.SLACK: "slack", }.get(source.platform, "telegram") - + # Use config override if present (list of toolsets), otherwise hardcoded default config_toolsets = platform_toolsets_config.get(platform_config_key) if config_toolsets and isinstance(config_toolsets, list): @@ -2432,7 +2463,7 @@ class GatewayRunner: else: default_toolset = default_toolset_map.get(source.platform, "hermes-telegram") enabled_toolsets = [default_toolset] - + # Tool progress mode from config.yaml: "all", "new", "verbose", "off" # Falls back to env vars for backward compatibility _progress_cfg = {} @@ -2440,32 +2471,29 @@ class GatewayRunner: _tp_cfg_path = _hermes_home / "config.yaml" if _tp_cfg_path.exists(): import yaml as _tp_yaml + with open(_tp_cfg_path) as _tp_f: _tp_data = _tp_yaml.safe_load(_tp_f) or {} _progress_cfg = _tp_data.get("display", {}) except Exception: pass - progress_mode = ( - _progress_cfg.get("tool_progress") - or os.getenv("HERMES_TOOL_PROGRESS_MODE") - or "all" - ) + progress_mode = _progress_cfg.get("tool_progress") or os.getenv("HERMES_TOOL_PROGRESS_MODE") or "all" tool_progress_enabled = progress_mode != "off" - + # Queue for progress messages (thread-safe) progress_queue = queue.Queue() if tool_progress_enabled else None last_tool = [None] # Mutable container for tracking in closure - + def progress_callback(tool_name: str, preview: str = None, args: dict = None): """Callback invoked by agent when a tool is called.""" if not progress_queue: return - + # "new" mode: only report when tool changes if progress_mode == "new" and tool_name == last_tool[0]: return last_tool[0] = tool_name - + # Build progress message with primary argument preview tool_emojis = { "terminal": "💻", @@ -2508,27 +2536,28 @@ class GatewayRunner: "skill_manage": "📝", } emoji = tool_emojis.get(tool_name, "⚙️") - + # Verbose mode: show detailed arguments if progress_mode == "verbose" and args: import json as _json + args_str = _json.dumps(args, ensure_ascii=False, default=str) if len(args_str) > 200: args_str = args_str[:197] + "..." msg = f"{emoji} {tool_name}({list(args.keys())})\n{args_str}" progress_queue.put(msg) return - + if preview: # Truncate preview to keep messages clean if len(preview) > 80: preview = preview[:77] + "..." - msg = f"{emoji} {tool_name}: \"{preview}\"" + msg = f'{emoji} {tool_name}: "{preview}"' else: msg = f"{emoji} {tool_name}..." - + progress_queue.put(msg) - + # Background task to send progress messages # Accumulates tool lines into a single message that gets edited async def send_progress_messages(): @@ -2539,9 +2568,9 @@ class GatewayRunner: if not adapter: return - progress_lines = [] # Accumulated tool lines - progress_msg_id = None # ID of the progress message to edit - can_edit = True # False once an edit fails (platform doesn't support it) + progress_lines = [] # Accumulated tool lines + progress_msg_id = None # ID of the progress message to edit + can_edit = True # False once an edit fails (platform doesn't support it) while True: try: @@ -2601,12 +2630,12 @@ class GatewayRunner: except Exception as e: logger.error("Progress message error: %s", e) await asyncio.sleep(1) - + # We need to share the agent instance for interrupt support agent_holder = [None] # Mutable container for the agent instance result_holder = [None] # Mutable container for the result - tools_holder = [None] # Mutable container for the tool definitions - + tools_holder = [None] # Mutable container for the tool definitions + # Bridge sync step_callback → async hooks.emit for agent:step events _loop_for_step = asyncio.get_event_loop() _hooks_ref = self.hooks @@ -2614,13 +2643,16 @@ class GatewayRunner: def _step_callback_sync(iteration: int, tool_names: list) -> None: try: asyncio.run_coroutine_threadsafe( - _hooks_ref.emit("agent:step", { - "platform": source.platform.value if source.platform else "", - "user_id": source.user_id, - "session_id": session_id, - "iteration": iteration, - "tool_names": tool_names, - }), + _hooks_ref.emit( + "agent:step", + { + "platform": source.platform.value if source.platform else "", + "user_id": source.user_id, + "session_id": session_id, + "iteration": iteration, + "tool_names": tool_names, + }, + ), _loop_for_step, ) except Exception as _e: @@ -2633,11 +2665,11 @@ class GatewayRunner: # Read from env var or use default (same as CLI) max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90")) - + # Map platform enum to the platform hint key the agent understands. # Platform.LOCAL ("local") maps to "cli"; others pass through as-is. platform_key = "cli" if source.platform == Platform.LOCAL else source.platform.value - + # Combine platform context with user-configured ephemeral system prompt combined_ephemeral = context_prompt or "" if self._ephemeral_system_prompt: @@ -2656,6 +2688,7 @@ class GatewayRunner: try: import yaml as _y + _cfg_path = _hermes_home / "config.yaml" if _cfg_path.exists(): with open(_cfg_path) as _f: @@ -2703,12 +2736,12 @@ class GatewayRunner: session_db=self._session_db, fallback_model=self._fallback_model, ) - + # Store agent reference for interrupt support agent_holder[0] = agent # Capture the full tool definitions for transcript logging - tools_holder[0] = agent.tools if hasattr(agent, 'tools') else None - + tools_holder[0] = agent.tools if hasattr(agent, "tools") else None + # Convert history to agent format. # Two cases: # 1. Normal path (from transcript): simple {role, content, timestamp} dicts @@ -2722,22 +2755,22 @@ class GatewayRunner: role = msg.get("role") if not role: continue - + # Skip metadata entries (tool definitions, session info) # -- these are for transcript logging, not for the LLM if role in ("session_meta",): continue - + # Skip system messages -- the agent rebuilds its own system prompt if role == "system": continue - + # Rich agent messages (tool_calls, tool results) must be passed # through intact so the API sees valid assistant→tool sequences has_tool_calls = "tool_calls" in msg has_tool_call_id = "tool_call_id" in msg is_tool_message = role == "tool" - + if has_tool_calls or has_tool_call_id or is_tool_message: clean_msg = {k: v for k, v in msg.items() if k != "timestamp"} agent_history.append(clean_msg) @@ -2750,7 +2783,7 @@ class GatewayRunner: mirror_src = msg.get("mirror_source", "another session") content = f"[Delivered from {mirror_src}] {content}" agent_history.append({"role": role, "content": content}) - + # Collect MEDIA paths already in history so we can exclude them # from the current turn's extraction. This is compression-safe: # even if the message list shrinks, we know which paths are old. @@ -2759,14 +2792,14 @@ class GatewayRunner: if _hm.get("role") in ("tool", "function"): _hc = _hm.get("content", "") if "MEDIA:" in _hc: - for _match in re.finditer(r'MEDIA:(\S+)', _hc): + for _match in re.finditer(r"MEDIA:(\S+)", _hc): _p = _match.group(1).strip().rstrip('",}') if _p: _history_media_paths.add(_p) - + result = agent.run_conversation(message, conversation_history=agent_history, task_id=session_id) result_holder[0] = result - + # Return final response, or a message if something went wrong final_response = result.get("final_response") if not final_response: @@ -2778,7 +2811,7 @@ class GatewayRunner: "tools": tools_holder[0] or [], "history_offset": len(agent_history), } - + # Scan tool results for MEDIA: tags that need to be delivered # as native audio/file attachments. The TTS tool embeds MEDIA: tags # in its JSON response, but the model's final text reply usually @@ -2796,13 +2829,13 @@ class GatewayRunner: if msg.get("role") in ("tool", "function"): content = msg.get("content", "") if "MEDIA:" in content: - for match in re.finditer(r'MEDIA:(\S+)', content): + for match in re.finditer(r"MEDIA:(\S+)", content): path = match.group(1).strip().rstrip('",}') if path and path not in _history_media_paths: media_tags.append(f"MEDIA:{path}") if "[[audio_as_voice]]" in content: has_voice_directive = True - + if media_tags: seen = set() unique_tags = [] @@ -2813,7 +2846,7 @@ class GatewayRunner: if has_voice_directive: unique_tags.insert(0, "[[audio_as_voice]]") final_response = final_response + "\n" + "\n".join(unique_tags) - + return { "final_response": final_response, "messages": result_holder[0].get("messages", []) if result_holder[0] else [], @@ -2821,12 +2854,12 @@ class GatewayRunner: "tools": tools_holder[0] or [], "history_offset": len(agent_history), } - + # Start progress message sender if enabled progress_task = None if tool_progress_enabled: progress_task = asyncio.create_task(send_progress_messages()) - + # Track this agent as running for this session (for interrupt support) # We do this in a callback after the agent is created async def track_agent(): @@ -2835,20 +2868,20 @@ class GatewayRunner: await asyncio.sleep(0.05) if session_key: self._running_agents[session_key] = agent_holder[0] - + tracking_task = asyncio.create_task(track_agent()) - + # Monitor for interrupts from the adapter (new messages arriving) async def monitor_for_interrupt(): adapter = self.adapters.get(source.platform) if not adapter: return - + chat_id = source.chat_id while True: await asyncio.sleep(0.2) # Check every 200ms # Check if adapter has a pending interrupt for this session - if hasattr(adapter, 'has_pending_interrupt') and adapter.has_pending_interrupt(chat_id): + if hasattr(adapter, "has_pending_interrupt") and adapter.has_pending_interrupt(chat_id): agent = agent_holder[0] if agent: pending_event = adapter.get_pending_message(chat_id) @@ -2856,18 +2889,18 @@ class GatewayRunner: logger.debug("Interrupt detected from adapter, signaling agent...") agent.interrupt(pending_text) break - + interrupt_monitor = asyncio.create_task(monitor_for_interrupt()) - + try: # Run in thread pool to not block loop = asyncio.get_event_loop() response = await loop.run_in_executor(None, run_sync) - + # Check if we were interrupted and have a pending message result = result_holder[0] adapter = self.adapters.get(source.platform) - + # Get pending message from adapter if interrupted pending = None if result and result.get("interrupted") and adapter: @@ -2876,20 +2909,20 @@ class GatewayRunner: pending = pending_event.text elif result.get("interrupt_message"): pending = result.get("interrupt_message") - + if pending: logger.debug("Processing interrupted message: '%s...'", pending[:40]) - + # Clear the adapter's interrupt event so the next _run_agent call # doesn't immediately re-trigger the interrupt before the new agent # even makes its first API call (this was causing an infinite loop). - if adapter and hasattr(adapter, '_active_sessions') and source.chat_id in adapter._active_sessions: + if adapter and hasattr(adapter, "_active_sessions") and source.chat_id in adapter._active_sessions: adapter._active_sessions[source.chat_id].clear() - + # Don't send the interrupted response to the user — it's just noise # like "Operation interrupted." They already know they sent a new # message, so go straight to processing it. - + # Now process the pending message with updated history updated_history = result.get("messages", history) return await self._run_agent( @@ -2898,19 +2931,19 @@ class GatewayRunner: history=updated_history, source=source, session_id=session_id, - session_key=session_key + session_key=session_key, ) finally: # Stop progress sender and interrupt monitor if progress_task: progress_task.cancel() interrupt_monitor.cancel() - + # Clean up tracking tracking_task.cancel() if session_key and session_key in self._running_agents: del self._running_agents[session_key] - + # Wait for cancelled tasks for task in [progress_task, interrupt_monitor, tracking_task]: if task: @@ -2918,14 +2951,14 @@ class GatewayRunner: await task except asyncio.CancelledError: pass - + return response def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int = 60): """ Background thread that ticks the cron scheduler at a regular interval. - + Runs inside the gateway process so cronjobs fire automatically without needing a separate `hermes cron daemon` or system cron entry. @@ -2933,10 +2966,10 @@ def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int image/audio/document cache once per hour. """ from cron.scheduler import tick as cron_tick - from gateway.platforms.base import cleanup_image_cache, cleanup_document_cache + from gateway.platforms.base import cleanup_document_cache, cleanup_image_cache - IMAGE_CACHE_EVERY = 60 # ticks — once per hour at default 60s interval - CHANNEL_DIR_EVERY = 5 # ticks — every 5 minutes + IMAGE_CACHE_EVERY = 60 # ticks — once per hour at default 60s interval + CHANNEL_DIR_EVERY = 5 # ticks — every 5 minutes logger.info("Cron ticker started (interval=%ds)", interval) tick_count = 0 @@ -2951,6 +2984,7 @@ def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int if tick_count % CHANNEL_DIR_EVERY == 0 and adapters: try: from gateway.channel_directory import build_channel_directory + build_channel_directory(adapters) except Exception as e: logger.debug("Channel directory refresh error: %s", e) @@ -2973,14 +3007,14 @@ def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int logger.info("Cron ticker stopped") -async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = False) -> bool: +async def start_gateway(config: GatewayConfig | None = None, replace: bool = False) -> bool: """ Start the gateway and run until interrupted. - + This is the main entry point for running the gateway. Returns True if the gateway ran successfully, False if it failed to start. A False return causes a non-zero exit code so systemd can auto-restart. - + Args: config: Optional gateway configuration override. replace: If True, kill any existing gateway instance before starting. @@ -2993,7 +3027,9 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = # setups (each profile using a distinct HERMES_HOME) will naturally # allow concurrent instances without tripping this guard. import time as _time + from gateway.status import get_running_pid, remove_pid_file + existing_pid = get_running_pid() if existing_pid is not None and existing_pid != os.getpid(): if replace: @@ -3035,7 +3071,8 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = logger.error( "Another gateway instance is already running (PID %d, HERMES_HOME=%s). " "Use 'hermes gateway restart' to replace it, or 'hermes gateway stop' first.", - existing_pid, hermes_home, + existing_pid, + hermes_home, ) print( f"\n❌ Gateway already running (PID {existing_pid}).\n" @@ -3048,57 +3085,61 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = # Sync bundled skills on gateway start (fast -- skips unchanged) try: from tools.skills_sync import sync_skills + sync_skills(quiet=True) except Exception: pass # Configure rotating file log so gateway output is persisted for debugging - log_dir = _hermes_home / 'logs' + log_dir = _hermes_home / "logs" log_dir.mkdir(parents=True, exist_ok=True) file_handler = RotatingFileHandler( - log_dir / 'gateway.log', + log_dir / "gateway.log", maxBytes=5 * 1024 * 1024, backupCount=3, ) from agent.redact import RedactingFormatter - file_handler.setFormatter(RedactingFormatter('%(asctime)s %(levelname)s %(name)s: %(message)s')) + + file_handler.setFormatter(RedactingFormatter("%(asctime)s %(levelname)s %(name)s: %(message)s")) logging.getLogger().addHandler(file_handler) logging.getLogger().setLevel(logging.INFO) # Separate errors-only log for easy debugging error_handler = RotatingFileHandler( - log_dir / 'errors.log', + log_dir / "errors.log", maxBytes=2 * 1024 * 1024, backupCount=2, ) error_handler.setLevel(logging.WARNING) - error_handler.setFormatter(RedactingFormatter('%(asctime)s %(levelname)s %(name)s: %(message)s')) + error_handler.setFormatter(RedactingFormatter("%(asctime)s %(levelname)s %(name)s: %(message)s")) logging.getLogger().addHandler(error_handler) runner = GatewayRunner(config) - + # Set up signal handlers def signal_handler(): asyncio.create_task(runner.stop()) - + loop = asyncio.get_event_loop() for sig in (signal.SIGINT, signal.SIGTERM): try: loop.add_signal_handler(sig, signal_handler) except NotImplementedError: pass - + # Start the gateway success = await runner.start() if not success: return False - + # Write PID file so CLI can detect gateway is running import atexit - from gateway.status import write_pid_file, remove_pid_file + + from gateway.status import remove_pid_file, write_pid_file + write_pid_file() atexit.register(remove_pid_file) - + # Start background cron ticker so scheduled jobs fire automatically cron_stop = threading.Event() cron_thread = threading.Thread( @@ -3109,10 +3150,10 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = name="cron-ticker", ) cron_thread.start() - + # Wait for shutdown await runner.wait_for_shutdown() - + # Stop cron ticker cleanly cron_stop.set() cron_thread.join(timeout=5) @@ -3120,6 +3161,7 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = # Close MCP server connections try: from tools.mcp_tool import shutdown_mcp_servers + shutdown_mcp_servers() except Exception: pass @@ -3130,20 +3172,21 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = def main(): """CLI entry point for the gateway.""" import argparse - + parser = argparse.ArgumentParser(description="Hermes Gateway - Multi-platform messaging") parser.add_argument("--config", "-c", help="Path to gateway config file") parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") - + args = parser.parse_args() - + config = None if args.config: import json + with open(args.config) as f: data = json.load(f) config = GatewayConfig.from_dict(data) - + # Run the gateway - exit with code 1 if no platforms connected, # so systemd Restart=on-failure will retry on transient errors (e.g. DNS) success = asyncio.run(start_gateway(config)) diff --git a/gateway/session.py b/gateway/session.py index dfe3f12efd..79ca086fbf 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -8,22 +8,20 @@ Handles: - Dynamic system prompt injection (agent knows its context) """ -import logging -import os import json +import logging import uuid -from pathlib import Path +from dataclasses import dataclass from datetime import datetime, timedelta -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Any +from pathlib import Path +from typing import Any logger = logging.getLogger(__name__) from .config import ( - Platform, GatewayConfig, - SessionResetPolicy, HomeChannel, + Platform, ) @@ -31,29 +29,30 @@ from .config import ( class SessionSource: """ Describes where a message originated from. - + This information is used to: 1. Route responses back to the right place 2. Inject context into the system prompt 3. Track origin for cron job delivery """ + platform: Platform chat_id: str - chat_name: Optional[str] = None + chat_name: str | None = None chat_type: str = "dm" # "dm", "group", "channel", "thread" - user_id: Optional[str] = None - user_name: Optional[str] = None - thread_id: Optional[str] = None # For forum topics, Discord threads, etc. - chat_topic: Optional[str] = None # Channel topic/description (Discord, Slack) - user_id_alt: Optional[str] = None # Signal UUID (alternative to phone number) - chat_id_alt: Optional[str] = None # Signal group internal ID - + user_id: str | None = None + user_name: str | None = None + thread_id: str | None = None # For forum topics, Discord threads, etc. + chat_topic: str | None = None # Channel topic/description (Discord, Slack) + user_id_alt: str | None = None # Signal UUID (alternative to phone number) + chat_id_alt: str | None = None # Signal group internal ID + @property def description(self) -> str: """Human-readable description of the source.""" if self.platform == Platform.LOCAL: return "CLI terminal" - + parts = [] if self.chat_type == "dm": parts.append(f"DM with {self.user_name or self.user_id or 'user'}") @@ -63,13 +62,13 @@ class SessionSource: parts.append(f"channel: {self.chat_name or self.chat_id}") else: parts.append(self.chat_name or self.chat_id) - + if self.thread_id: parts.append(f"thread: {self.thread_id}") - + return ", ".join(parts) - - def to_dict(self) -> Dict[str, Any]: + + def to_dict(self) -> dict[str, Any]: d = { "platform": self.platform.value, "chat_id": self.chat_id, @@ -85,9 +84,9 @@ class SessionSource: if self.chat_id_alt: d["chat_id_alt"] = self.chat_id_alt return d - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SessionSource": + def from_dict(cls, data: dict[str, Any]) -> "SessionSource": return cls( platform=Platform(data["platform"]), chat_id=str(data["chat_id"]), @@ -100,7 +99,7 @@ class SessionSource: user_id_alt=data.get("user_id_alt"), chat_id_alt=data.get("chat_id_alt"), ) - + @classmethod def local_cli(cls) -> "SessionSource": """Create a source representing the local CLI.""" @@ -116,29 +115,28 @@ class SessionSource: class SessionContext: """ Full context for a session, used for dynamic system prompt injection. - + The agent receives this information to understand: - Where messages are coming from - What platforms are available - Where it can deliver scheduled task outputs """ + source: SessionSource - connected_platforms: List[Platform] - home_channels: Dict[Platform, HomeChannel] - + connected_platforms: list[Platform] + home_channels: dict[Platform, HomeChannel] + # Session metadata session_key: str = "" session_id: str = "" - created_at: Optional[datetime] = None - updated_at: Optional[datetime] = None - - def to_dict(self) -> Dict[str, Any]: + created_at: datetime | None = None + updated_at: datetime | None = None + + def to_dict(self) -> dict[str, Any]: return { "source": self.source.to_dict(), "connected_platforms": [p.value for p in self.connected_platforms], - "home_channels": { - p.value: hc.to_dict() for p, hc in self.home_channels.items() - }, + "home_channels": {p.value: hc.to_dict() for p, hc in self.home_channels.items()}, "session_key": self.session_key, "session_id": self.session_id, "created_at": self.created_at.isoformat() if self.created_at else None, @@ -149,7 +147,7 @@ class SessionContext: def build_session_context_prompt(context: SessionContext) -> str: """ Build the dynamic system prompt section that tells the agent about its context. - + This is injected into the system prompt so the agent knows: - Where messages are coming from - What platforms are connected @@ -159,14 +157,14 @@ def build_session_context_prompt(context: SessionContext) -> str: "## Current Session Context", "", ] - + # Source info platform_name = context.source.platform.value.title() if context.source.platform == Platform.LOCAL: lines.append(f"**Source:** {platform_name} (the machine running this agent)") else: lines.append(f"**Source:** {platform_name} ({context.source.description})") - + # Channel topic (if available - provides context about the channel's purpose) if context.source.chat_topic: lines.append(f"**Channel Topic:** {context.source.chat_topic}") @@ -176,43 +174,43 @@ def build_session_context_prompt(context: SessionContext) -> str: lines.append(f"**User:** {context.source.user_name}") elif context.source.user_id: lines.append(f"**User ID:** {context.source.user_id}") - + # Connected platforms platforms_list = ["local (files on this machine)"] for p in context.connected_platforms: if p != Platform.LOCAL: platforms_list.append(f"{p.value}: Connected ✓") - + lines.append(f"**Connected Platforms:** {', '.join(platforms_list)}") - + # Home channels if context.home_channels: lines.append("") lines.append("**Home Channels (default destinations):**") for platform, home in context.home_channels.items(): lines.append(f" - {platform.value}: {home.name} (ID: {home.chat_id})") - + # Delivery options for scheduled tasks lines.append("") lines.append("**Delivery options for scheduled tasks:**") - + # Origin delivery if context.source.platform == Platform.LOCAL: - lines.append("- `\"origin\"` → Local output (saved to files)") + lines.append('- `"origin"` → Local output (saved to files)') else: - lines.append(f"- `\"origin\"` → Back to this chat ({context.source.chat_name or context.source.chat_id})") - + lines.append(f'- `"origin"` → Back to this chat ({context.source.chat_name or context.source.chat_id})') + # Local always available - lines.append("- `\"local\"` → Save to local files only (~/.hermes/cron/output/)") - + lines.append('- `"local"` → Save to local files only (~/.hermes/cron/output/)') + # Platform home channels for platform, home in context.home_channels.items(): - lines.append(f"- `\"{platform.value}\"` → Home channel ({home.name})") - + lines.append(f'- `"{platform.value}"` → Home channel ({home.name})') + # Note about explicit targeting lines.append("") - lines.append("*For explicit targeting, use `\"platform:chat_id\"` format if the user provides a specific chat ID.*") - + lines.append('*For explicit targeting, use `"platform:chat_id"` format if the user provides a specific chat ID.*') + return "\n".join(lines) @@ -220,32 +218,33 @@ def build_session_context_prompt(context: SessionContext) -> str: class SessionEntry: """ Entry in the session store. - + Maps a session key to its current session ID and metadata. """ + session_key: str session_id: str created_at: datetime updated_at: datetime - + # Origin metadata for delivery routing - origin: Optional[SessionSource] = None - + origin: SessionSource | None = None + # Display metadata - display_name: Optional[str] = None - platform: Optional[Platform] = None + display_name: str | None = None + platform: Platform | None = None chat_type: str = "dm" - + # Token tracking input_tokens: int = 0 output_tokens: int = 0 total_tokens: int = 0 - + # Set when a session was created because the previous one expired; # consumed once by the message handler to inject a notice into context was_auto_reset: bool = False - - def to_dict(self) -> Dict[str, Any]: + + def to_dict(self) -> dict[str, Any]: result = { "session_key": self.session_key, "session_id": self.session_id, @@ -261,20 +260,20 @@ class SessionEntry: if self.origin: result["origin"] = self.origin.to_dict() return result - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SessionEntry": + def from_dict(cls, data: dict[str, Any]) -> "SessionEntry": origin = None if "origin" in data and data["origin"]: origin = SessionSource.from_dict(data["origin"]) - + platform = None if data.get("platform"): try: platform = Platform(data["platform"]) except ValueError: pass - + return cls( session_key=data["session_key"], session_id=data["session_id"], @@ -307,66 +306,65 @@ def build_session_key(source: SessionSource) -> str: class SessionStore: """ Manages session storage and retrieval. - + Uses SQLite (via SessionDB) for session metadata and message transcripts. Falls back to legacy JSONL files if SQLite is unavailable. """ - - def __init__(self, sessions_dir: Path, config: GatewayConfig, - has_active_processes_fn=None, - on_auto_reset=None): + + def __init__(self, sessions_dir: Path, config: GatewayConfig, has_active_processes_fn=None, on_auto_reset=None): self.sessions_dir = sessions_dir self.config = config - self._entries: Dict[str, SessionEntry] = {} + self._entries: dict[str, SessionEntry] = {} self._loaded = False self._has_active_processes_fn = has_active_processes_fn # on_auto_reset is deprecated — memory flush now runs proactively # via the background session expiry watcher in GatewayRunner. self._pre_flushed_sessions: set = set() # session_ids already flushed by watcher - + # Initialize SQLite session database self._db = None try: from hermes_state import SessionDB + self._db = SessionDB() except Exception as e: print(f"[gateway] Warning: SQLite session store unavailable, falling back to JSONL: {e}") - + def _ensure_loaded(self) -> None: """Load sessions index from disk if not already loaded.""" if self._loaded: return - + self.sessions_dir.mkdir(parents=True, exist_ok=True) sessions_file = self.sessions_dir / "sessions.json" - + if sessions_file.exists(): try: - with open(sessions_file, "r", encoding="utf-8") as f: + with open(sessions_file, encoding="utf-8") as f: data = json.load(f) for key, entry_data in data.items(): self._entries[key] = SessionEntry.from_dict(entry_data) except Exception as e: print(f"[gateway] Warning: Failed to load sessions: {e}") - + self._loaded = True - + def _save(self) -> None: """Save sessions index to disk (kept for session key -> ID mapping).""" self.sessions_dir.mkdir(parents=True, exist_ok=True) sessions_file = self.sessions_dir / "sessions.json" - + data = {key: entry.to_dict() for key, entry in self._entries.items()} with open(sessions_file, "w", encoding="utf-8") as f: json.dump(data, f, indent=2) - + def _generate_session_key(self, source: SessionSource) -> str: """Generate a session key from a source.""" return build_session_key(source) - + def _is_session_expired(self, entry: SessionEntry) -> bool: """Check if a session has expired based on its reset policy. - + Works from the entry alone — no SessionSource needed. Used by the background expiry watcher to proactively flush memories. Sessions with active background processes are never considered expired. @@ -393,7 +391,9 @@ class SessionStore: if policy.mode in ("daily", "both"): today_reset = now.replace( hour=policy.at_hour, - minute=0, second=0, microsecond=0, + minute=0, + second=0, + microsecond=0, ) if now.hour < policy.at_hour: today_reset -= timedelta(days=1) @@ -405,7 +405,7 @@ class SessionStore: def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool: """ Check if a session should be reset based on policy. - + Sessions with active background processes are never reset. """ if self._has_active_processes_fn: @@ -413,36 +413,28 @@ class SessionStore: if self._has_active_processes_fn(session_key): return False - policy = self.config.get_reset_policy( - platform=source.platform, - session_type=source.chat_type - ) - + policy = self.config.get_reset_policy(platform=source.platform, session_type=source.chat_type) + if policy.mode == "none": return False - + now = datetime.now() - + if policy.mode in ("idle", "both"): idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes) if now > idle_deadline: return True - + if policy.mode in ("daily", "both"): - today_reset = now.replace( - hour=policy.at_hour, - minute=0, - second=0, - microsecond=0 - ) + today_reset = now.replace(hour=policy.at_hour, minute=0, second=0, microsecond=0) if now.hour < policy.at_hour: today_reset -= timedelta(days=1) - + if entry.updated_at < today_reset: return True - + return False - + def has_any_sessions(self) -> bool: """Check if any sessions have ever been created (across all platforms). @@ -463,26 +455,22 @@ class SessionStore: # This covers the rare case where the DB is unavailable. self._ensure_loaded() return len(self._entries) > 1 - - def get_or_create_session( - self, - source: SessionSource, - force_new: bool = False - ) -> SessionEntry: + + def get_or_create_session(self, source: SessionSource, force_new: bool = False) -> SessionEntry: """ Get an existing session or create a new one. - + Evaluates reset policy to determine if the existing session is stale. Creates a session record in SQLite when a new session starts. """ self._ensure_loaded() - + session_key = self._generate_session_key(source) now = datetime.now() - + if session_key in self._entries and not force_new: entry = self._entries[session_key] - + if not self._should_reset(entry, source): entry.updated_at = now self._save() @@ -500,10 +488,10 @@ class SessionStore: logger.debug("Session DB operation failed: %s", e) else: was_auto_reset = False - + # Create new session session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" - + entry = SessionEntry( session_key=session_key, session_id=session_id, @@ -515,10 +503,10 @@ class SessionStore: chat_type=source.chat_type, was_auto_reset=was_auto_reset, ) - + self._entries[session_key] = entry self._save() - + # Create session in SQLite if self._db: try: @@ -529,18 +517,13 @@ class SessionStore: ) except Exception as e: print(f"[gateway] Warning: Failed to create SQLite session: {e}") - + return entry - - def update_session( - self, - session_key: str, - input_tokens: int = 0, - output_tokens: int = 0 - ) -> None: + + def update_session(self, session_key: str, input_tokens: int = 0, output_tokens: int = 0) -> None: """Update a session's metadata after an interaction.""" self._ensure_loaded() - + if session_key in self._entries: entry = self._entries[session_key] entry.updated_at = datetime.now() @@ -548,34 +531,32 @@ class SessionStore: entry.output_tokens += output_tokens entry.total_tokens = entry.input_tokens + entry.output_tokens self._save() - + if self._db: try: - self._db.update_token_counts( - entry.session_id, input_tokens, output_tokens - ) + self._db.update_token_counts(entry.session_id, input_tokens, output_tokens) except Exception as e: logger.debug("Session DB operation failed: %s", e) - - def reset_session(self, session_key: str) -> Optional[SessionEntry]: + + def reset_session(self, session_key: str) -> SessionEntry | None: """Force reset a session, creating a new session ID.""" self._ensure_loaded() - + if session_key not in self._entries: return None - + old_entry = self._entries[session_key] - + # End old session in SQLite if self._db: try: self._db.end_session(old_entry.session_id, "session_reset") except Exception as e: logger.debug("Session DB operation failed: %s", e) - + now = datetime.now() session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" - + new_entry = SessionEntry( session_key=session_key, session_id=session_id, @@ -586,10 +567,10 @@ class SessionStore: platform=old_entry.platform, chat_type=old_entry.chat_type, ) - + self._entries[session_key] = new_entry self._save() - + # Create new session in SQLite if self._db: try: @@ -600,10 +581,10 @@ class SessionStore: ) except Exception as e: logger.debug("Session DB operation failed: %s", e) - + return new_entry - def switch_session(self, session_key: str, target_session_id: str) -> Optional[SessionEntry]: + def switch_session(self, session_key: str, target_session_id: str) -> SessionEntry | None: """Switch a session key to point at an existing session ID. Used by ``/resume`` to restore a previously-named session. @@ -645,25 +626,25 @@ class SessionStore: self._save() return new_entry - def list_sessions(self, active_minutes: Optional[int] = None) -> List[SessionEntry]: + def list_sessions(self, active_minutes: int | None = None) -> list[SessionEntry]: """List all sessions, optionally filtered by activity.""" self._ensure_loaded() - + entries = list(self._entries.values()) - + if active_minutes is not None: cutoff = datetime.now() - timedelta(minutes=active_minutes) entries = [e for e in entries if e.updated_at >= cutoff] - + entries.sort(key=lambda e: e.updated_at, reverse=True) - + return entries - + def get_transcript_path(self, session_id: str) -> Path: """Get the path to a session's legacy transcript file.""" return self.sessions_dir / f"{session_id}.jsonl" - - def append_to_transcript(self, session_id: str, message: Dict[str, Any]) -> None: + + def append_to_transcript(self, session_id: str, message: dict[str, Any]) -> None: """Append a message to a session's transcript (SQLite + legacy JSONL).""" # Write to SQLite if self._db: @@ -678,15 +659,15 @@ class SessionStore: ) except Exception as e: logger.debug("Session DB operation failed: %s", e) - + # Also write legacy JSONL (keeps existing tooling working during transition) transcript_path = self.get_transcript_path(session_id) with open(transcript_path, "a", encoding="utf-8") as f: f.write(json.dumps(message, ensure_ascii=False) + "\n") - - def rewrite_transcript(self, session_id: str, messages: List[Dict[str, Any]]) -> None: + + def rewrite_transcript(self, session_id: str, messages: list[dict[str, Any]]) -> None: """Replace the entire transcript for a session with new messages. - + Used by /retry, /undo, and /compress to persist modified conversation history. Rewrites both SQLite and legacy JSONL storage. """ @@ -705,14 +686,14 @@ class SessionStore: ) except Exception as e: logger.debug("Failed to rewrite transcript in DB: %s", e) - + # JSONL: overwrite the file transcript_path = self.get_transcript_path(session_id) with open(transcript_path, "w", encoding="utf-8") as f: for msg in messages: f.write(json.dumps(msg, ensure_ascii=False) + "\n") - def load_transcript(self, session_id: str) -> List[Dict[str, Any]]: + def load_transcript(self, session_id: str) -> list[dict[str, Any]]: """Load all messages from a session's transcript.""" # Try SQLite first if self._db: @@ -722,51 +703,49 @@ class SessionStore: return messages except Exception as e: logger.debug("Could not load messages from DB: %s", e) - + # Fall back to legacy JSONL transcript_path = self.get_transcript_path(session_id) - + if not transcript_path.exists(): return [] - + messages = [] - with open(transcript_path, "r", encoding="utf-8") as f: + with open(transcript_path, encoding="utf-8") as f: for line in f: line = line.strip() if line: messages.append(json.loads(line)) - + return messages def build_session_context( - source: SessionSource, - config: GatewayConfig, - session_entry: Optional[SessionEntry] = None + source: SessionSource, config: GatewayConfig, session_entry: SessionEntry | None = None ) -> SessionContext: """ Build a full session context from a source and config. - + This is used to inject context into the agent's system prompt. """ connected = config.get_connected_platforms() - + home_channels = {} for platform in connected: home = config.get_home_channel(platform) if home: home_channels[platform] = home - + context = SessionContext( source=source, connected_platforms=connected, home_channels=home_channels, ) - + if session_entry: context.session_key = session_entry.session_key context.session_id = session_entry.session_id context.created_at = session_entry.created_at context.updated_at = session_entry.updated_at - + return context diff --git a/gateway/status.py b/gateway/status.py index 78d71947fd..c7423105c2 100644 --- a/gateway/status.py +++ b/gateway/status.py @@ -13,7 +13,6 @@ concurrently under distinct configurations). import os from pathlib import Path -from typing import Optional def _get_pid_path() -> Path: @@ -37,7 +36,7 @@ def remove_pid_file() -> None: pass -def get_running_pid() -> Optional[int]: +def get_running_pid() -> int | None: """Return the PID of a running gateway instance, or ``None``. Checks the PID file and verifies the process is actually alive. diff --git a/gateway/sticker_cache.py b/gateway/sticker_cache.py index 597f672ef8..1af9def9fd 100644 --- a/gateway/sticker_cache.py +++ b/gateway/sticker_cache.py @@ -12,8 +12,6 @@ import json import os import time from pathlib import Path -from typing import Optional - CACHE_PATH = Path(os.path.expanduser("~/.hermes/sticker_cache.json")) @@ -43,7 +41,7 @@ def _save_cache(cache: dict) -> None: ) -def get_cached_description(file_unique_id: str) -> Optional[dict]: +def get_cached_description(file_unique_id: str) -> dict | None: """ Look up a cached sticker description. @@ -92,11 +90,11 @@ def build_sticker_injection( """ context = "" if set_name and emoji: - context = f" {emoji} from \"{set_name}\"" + context = f' {emoji} from "{set_name}"' elif emoji: context = f" {emoji}" - return f"[The user sent a sticker{context}~ It shows: \"{description}\" (=^.w.^=)]" + return f'[The user sent a sticker{context}~ It shows: "{description}" (=^.w.^=)]' def build_animated_sticker_injection(emoji: str = "") -> str: diff --git a/hermes_cli/__init__.py b/hermes_cli/__init__.py index 7e647afc35..c07146d0ed 100644 --- a/hermes_cli/__init__.py +++ b/hermes_cli/__init__.py @@ -5,7 +5,7 @@ Provides subcommands for: - hermes chat - Interactive chat (same as ./hermes) - hermes gateway - Run gateway in foreground - hermes gateway start - Start gateway service -- hermes gateway stop - Stop gateway service +- hermes gateway stop - Stop gateway service - hermes setup - Interactive setup wizard - hermes status - Show status of all components - hermes cron - Manage cron jobs diff --git a/hermes_cli/auth.py b/hermes_cli/auth.py index 209f729595..87bc861f0e 100644 --- a/hermes_cli/auth.py +++ b/hermes_cli/auth.py @@ -15,27 +15,25 @@ Architecture: from __future__ import annotations +import base64 +import hashlib import json import logging import os -import shutil import stat -import base64 -import hashlib -import subprocess import time import uuid import webbrowser from contextlib import contextmanager from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any import httpx import yaml -from hermes_cli.config import get_hermes_home, get_config_path +from hermes_cli.config import get_config_path, get_hermes_home from hermes_constants import OPENROUTER_BASE_URL logger = logging.getLogger(__name__) @@ -58,8 +56,8 @@ DEFAULT_NOUS_INFERENCE_URL = "https://inference-api.nousresearch.com/v1" DEFAULT_NOUS_CLIENT_ID = "hermes-cli" DEFAULT_NOUS_SCOPE = "inference:mint_agent_key" DEFAULT_AGENT_KEY_MIN_TTL_SECONDS = 30 * 60 # 30 minutes -ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry -DEVICE_AUTH_POLL_INTERVAL_CAP_SECONDS = 1 # poll at most every 1s +ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry +DEVICE_AUTH_POLL_INTERVAL_CAP_SECONDS = 1 # poll at most every 1s DEFAULT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex" CODEX_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" CODEX_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token" @@ -70,9 +68,11 @@ CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # Provider Registry # ============================================================================= + @dataclass class ProviderConfig: """Describes a known inference provider.""" + id: str name: str auth_type: str # "oauth_device_code", "oauth_external", or "api_key" @@ -80,14 +80,14 @@ class ProviderConfig: inference_base_url: str = "" client_id: str = "" scope: str = "" - extra: Dict[str, Any] = field(default_factory=dict) + extra: dict[str, Any] = field(default_factory=dict) # For API-key providers: env vars to check (in priority order) api_key_env_vars: tuple = () # Optional env var for base URL override base_url_env_var: str = "" -PROVIDER_REGISTRY: Dict[str, ProviderConfig] = { +PROVIDER_REGISTRY: dict[str, ProviderConfig] = { "nous": ProviderConfig( id="nous", name="Nous Portal", @@ -172,14 +172,14 @@ def _resolve_kimi_base_url(api_key: str, default_url: str, env_override: str) -> ZAI_ENDPOINTS = [ # (id, base_url, default_model, label) - ("global", "https://api.z.ai/api/paas/v4", "glm-5", "Global"), - ("cn", "https://open.bigmodel.cn/api/paas/v4", "glm-5", "China"), - ("coding-global", "https://api.z.ai/api/coding/paas/v4", "glm-4.7", "Global (Coding Plan)"), - ("coding-cn", "https://open.bigmodel.cn/api/coding/paas/v4", "glm-4.7", "China (Coding Plan)"), + ("global", "https://api.z.ai/api/paas/v4", "glm-5", "Global"), + ("cn", "https://open.bigmodel.cn/api/paas/v4", "glm-5", "China"), + ("coding-global", "https://api.z.ai/api/coding/paas/v4", "glm-4.7", "Global (Coding Plan)"), + ("coding-cn", "https://open.bigmodel.cn/api/coding/paas/v4", "glm-4.7", "China (Coding Plan)"), ] -def detect_zai_endpoint(api_key: str, timeout: float = 8.0) -> Optional[Dict[str, str]]: +def detect_zai_endpoint(api_key: str, timeout: float = 8.0) -> dict[str, str] | None: """Probe z.ai endpoints to find one that accepts this API key. Returns {"id": ..., "base_url": ..., "model": ..., "label": ...} for the @@ -219,6 +219,7 @@ def detect_zai_endpoint(api_key: str, timeout: float = 8.0) -> Optional[Dict[str # Error Types # ============================================================================= + class AuthError(RuntimeError): """Structured auth error with UX mapping hints.""" @@ -227,7 +228,7 @@ class AuthError(RuntimeError): message: str, *, provider: str = "", - code: Optional[str] = None, + code: str | None = None, relogin_required: bool = False, ) -> None: super().__init__(message) @@ -245,16 +246,10 @@ def format_auth_error(error: Exception) -> str: return f"{error} Run `hermes model` to re-authenticate." if error.code == "subscription_required": - return ( - "No active paid subscription found on Nous Portal. " - "Please purchase/activate a subscription, then retry." - ) + return "No active paid subscription found on Nous Portal. Please purchase/activate a subscription, then retry." if error.code == "insufficient_credits": - return ( - "Subscription credits are exhausted. " - "Top up/renew credits in Nous Portal, then retry." - ) + return "Subscription credits are exhausted. Top up/renew credits in Nous Portal, then retry." if error.code == "temporarily_unavailable": return f"{error} Please retry in a few seconds." @@ -262,7 +257,7 @@ def format_auth_error(error: Exception) -> str: return str(error) -def _token_fingerprint(token: Any) -> Optional[str]: +def _token_fingerprint(token: Any) -> str | None: """Return a short hash fingerprint for telemetry without leaking token bytes.""" if not isinstance(token, str): return None @@ -277,10 +272,10 @@ def _oauth_trace_enabled() -> bool: return raw in {"1", "true", "yes", "on"} -def _oauth_trace(event: str, *, sequence_id: Optional[str] = None, **fields: Any) -> None: +def _oauth_trace(event: str, *, sequence_id: str | None = None, **fields: Any) -> None: if not _oauth_trace_enabled(): return - payload: Dict[str, Any] = {"event": event} + payload: dict[str, Any] = {"event": event} if sequence_id: payload["sequence_id"] = sequence_id payload.update(fields) @@ -291,6 +286,7 @@ def _oauth_trace(event: str, *, sequence_id: Optional[str] = None, **fields: Any # Auth Store — persistence layer for ~/.hermes/auth.json # ============================================================================= + def _auth_file_path() -> Path: return get_hermes_home() / "auth.json" @@ -326,7 +322,7 @@ def _auth_store_lock(timeout_seconds: float = AUTH_LOCK_TIMEOUT_SECONDS): fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) -def _load_auth_store(auth_file: Optional[Path] = None) -> Dict[str, Any]: +def _load_auth_store(auth_file: Path | None = None) -> dict[str, Any]: auth_file = auth_file or _auth_file_path() if not auth_file.exists(): return {"version": AUTH_STORE_VERSION, "providers": {}} @@ -345,17 +341,16 @@ def _load_auth_store(auth_file: Optional[Path] = None) -> Dict[str, Any]: providers = {} if "nous_portal" in systems: providers["nous"] = systems["nous_portal"] - return {"version": AUTH_STORE_VERSION, "providers": providers, - "active_provider": "nous" if providers else None} + return {"version": AUTH_STORE_VERSION, "providers": providers, "active_provider": "nous" if providers else None} return {"version": AUTH_STORE_VERSION, "providers": {}} -def _save_auth_store(auth_store: Dict[str, Any]) -> Path: +def _save_auth_store(auth_store: dict[str, Any]) -> Path: auth_file = _auth_file_path() auth_file.parent.mkdir(parents=True, exist_ok=True) auth_store["version"] = AUTH_STORE_VERSION - auth_store["updated_at"] = datetime.now(timezone.utc).isoformat() + auth_store["updated_at"] = datetime.now(UTC).isoformat() payload = json.dumps(auth_store, indent=2) + "\n" tmp_path = auth_file.with_name(f"{auth_file.name}.tmp.{os.getpid()}.{uuid.uuid4().hex}") try: @@ -387,7 +382,7 @@ def _save_auth_store(auth_store: Dict[str, Any]) -> Path: return auth_file -def _load_provider_state(auth_store: Dict[str, Any], provider_id: str) -> Optional[Dict[str, Any]]: +def _load_provider_state(auth_store: dict[str, Any], provider_id: str) -> dict[str, Any] | None: providers = auth_store.get("providers") if not isinstance(providers, dict): return None @@ -395,7 +390,7 @@ def _load_provider_state(auth_store: Dict[str, Any], provider_id: str) -> Option return dict(state) if isinstance(state, dict) else None -def _save_provider_state(auth_store: Dict[str, Any], provider_id: str, state: Dict[str, Any]) -> None: +def _save_provider_state(auth_store: dict[str, Any], provider_id: str, state: dict[str, Any]) -> None: providers = auth_store.setdefault("providers", {}) if not isinstance(providers, dict): auth_store["providers"] = {} @@ -404,19 +399,19 @@ def _save_provider_state(auth_store: Dict[str, Any], provider_id: str, state: Di auth_store["active_provider"] = provider_id -def get_provider_auth_state(provider_id: str) -> Optional[Dict[str, Any]]: +def get_provider_auth_state(provider_id: str) -> dict[str, Any] | None: """Return persisted auth state for a provider, or None.""" auth_store = _load_auth_store() return _load_provider_state(auth_store, provider_id) -def get_active_provider() -> Optional[str]: +def get_active_provider() -> str | None: """Return the currently active provider ID from auth store.""" auth_store = _load_auth_store() return auth_store.get("active_provider") -def clear_provider_auth(provider_id: Optional[str] = None) -> bool: +def clear_provider_auth(provider_id: str | None = None) -> bool: """ Clear auth state for a provider. Used by `hermes logout`. If provider_id is None, clears the active provider. @@ -455,11 +450,12 @@ def deactivate_provider() -> None: # Provider Resolution — picks which provider to use # ============================================================================= + def resolve_provider( - requested: Optional[str] = None, + requested: str | None = None, *, - explicit_api_key: Optional[str] = None, - explicit_base_url: Optional[str] = None, + explicit_api_key: str | None = None, + explicit_base_url: str | None = None, ) -> str: """ Determine which inference provider to use. @@ -475,9 +471,14 @@ def resolve_provider( # Normalize provider aliases _PROVIDER_ALIASES = { - "glm": "zai", "z-ai": "zai", "z.ai": "zai", "zhipu": "zai", - "kimi": "kimi-coding", "moonshot": "kimi-coding", - "minimax-china": "minimax-cn", "minimax_cn": "minimax-cn", + "glm": "zai", + "z-ai": "zai", + "z.ai": "zai", + "zhipu": "zai", + "kimi": "kimi-coding", + "moonshot": "kimi-coding", + "minimax-china": "minimax-cn", + "minimax_cn": "minimax-cn", } normalized = _PROVIDER_ALIASES.get(normalized, normalized) @@ -524,7 +525,8 @@ def resolve_provider( # Timestamp / TTL helpers # ============================================================================= -def _parse_iso_timestamp(value: Any) -> Optional[float]: + +def _parse_iso_timestamp(value: Any) -> float | None: if not isinstance(value, str) or not value: return None text = value.strip() @@ -537,7 +539,7 @@ def _parse_iso_timestamp(value: Any) -> Optional[float]: except Exception: return None if parsed.tzinfo is None: - parsed = parsed.replace(tzinfo=timezone.utc) + parsed = parsed.replace(tzinfo=UTC) return parsed.timestamp() @@ -556,14 +558,14 @@ def _coerce_ttl_seconds(expires_in: Any) -> int: return max(0, ttl) -def _optional_base_url(value: Any) -> Optional[str]: +def _optional_base_url(value: Any) -> str | None: if not isinstance(value, str): return None cleaned = value.strip().rstrip("/") return cleaned if cleaned else None -def _decode_jwt_claims(token: Any) -> Dict[str, Any]: +def _decode_jwt_claims(token: Any) -> dict[str, Any]: if not isinstance(token, str) or token.count(".") != 2: return {} payload = token.split(".")[1] @@ -588,6 +590,7 @@ def _codex_access_token_is_expiring(access_token: Any, skew_seconds: int) -> boo # SSH / remote session detection # ============================================================================= + def _is_remote_session() -> bool: """Detect if running in an SSH session where webbrowser.open() won't work.""" return bool(os.getenv("SSH_CLIENT") or os.getenv("SSH_TTY")) @@ -601,9 +604,10 @@ def _is_remote_session() -> bool: # where one app's refresh invalidates the other's session. # ============================================================================= -def _read_codex_tokens(*, _lock: bool = True) -> Dict[str, Any]: + +def _read_codex_tokens(*, _lock: bool = True) -> dict[str, Any]: """Read Codex OAuth tokens from Hermes auth store (~/.hermes/auth.json). - + Returns dict with 'tokens' (access_token, refresh_token) and 'last_refresh'. Raises AuthError if no Codex tokens are stored. """ @@ -650,10 +654,10 @@ def _read_codex_tokens(*, _lock: bool = True) -> Dict[str, Any]: } -def _save_codex_tokens(tokens: Dict[str, str], last_refresh: str = None) -> None: +def _save_codex_tokens(tokens: dict[str, str], last_refresh: str = None) -> None: """Save Codex OAuth tokens to Hermes auth store (~/.hermes/auth.json).""" if last_refresh is None: - last_refresh = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") + last_refresh = datetime.now(UTC).isoformat().replace("+00:00", "Z") with _auth_store_lock(): auth_store = _load_auth_store() state = _load_provider_state(auth_store, "openai-codex") or {} @@ -665,11 +669,11 @@ def _save_codex_tokens(tokens: Dict[str, str], last_refresh: str = None) -> None def _refresh_codex_auth_tokens( - tokens: Dict[str, str], + tokens: dict[str, str], timeout_seconds: float, -) -> Dict[str, str]: +) -> dict[str, str]: """Refresh Codex access token using the refresh token. - + Saves the new tokens to Hermes auth store automatically. """ refresh_token = tokens.get("refresh_token") @@ -746,9 +750,9 @@ def _refresh_codex_auth_tokens( return updated_tokens -def _import_codex_cli_tokens() -> Optional[Dict[str, str]]: +def _import_codex_cli_tokens() -> dict[str, str] | None: """Try to read tokens from ~/.codex/auth.json (Codex CLI shared file). - + Returns tokens dict if valid, None otherwise. Does NOT write to the shared file. """ codex_home = os.getenv("CODEX_HOME", "").strip() @@ -774,7 +778,7 @@ def resolve_codex_runtime_credentials( force_refresh: bool = False, refresh_if_expiring: bool = True, refresh_skew_seconds: int = CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Resolve runtime credentials from Hermes's own Codex token store.""" try: data = _read_codex_tokens() @@ -817,10 +821,7 @@ def resolve_codex_runtime_credentials( tokens = _refresh_codex_auth_tokens(tokens, refresh_timeout_seconds) access_token = str(tokens.get("access_token", "") or "").strip() - base_url = ( - os.getenv("HERMES_CODEX_BASE_URL", "").strip().rstrip("/") - or DEFAULT_CODEX_BASE_URL - ) + base_url = os.getenv("HERMES_CODEX_BASE_URL", "").strip().rstrip("/") or DEFAULT_CODEX_BASE_URL return { "provider": "openai-codex", @@ -836,24 +837,19 @@ def resolve_codex_runtime_credentials( # TLS verification helper # ============================================================================= + def _resolve_verify( *, - insecure: Optional[bool] = None, - ca_bundle: Optional[str] = None, - auth_state: Optional[Dict[str, Any]] = None, + insecure: bool | None = None, + ca_bundle: str | None = None, + auth_state: dict[str, Any] | None = None, ) -> bool | str: tls_state = auth_state.get("tls") if isinstance(auth_state, dict) else {} tls_state = tls_state if isinstance(tls_state, dict) else {} - effective_insecure = ( - bool(insecure) if insecure is not None - else bool(tls_state.get("insecure", False)) - ) + effective_insecure = bool(insecure) if insecure is not None else bool(tls_state.get("insecure", False)) effective_ca = ( - ca_bundle - or tls_state.get("ca_bundle") - or os.getenv("HERMES_CA_BUNDLE") - or os.getenv("SSL_CERT_FILE") + ca_bundle or tls_state.get("ca_bundle") or os.getenv("HERMES_CA_BUNDLE") or os.getenv("SSL_CERT_FILE") ) if effective_insecure: @@ -867,12 +863,13 @@ def _resolve_verify( # OAuth Device Code Flow — generic, parameterized by provider # ============================================================================= + def _request_device_code( client: httpx.Client, portal_base_url: str, client_id: str, - scope: Optional[str], -) -> Dict[str, Any]: + scope: str | None, +) -> dict[str, Any]: """POST to the device code endpoint. Returns device_code, user_code, etc.""" response = client.post( f"{portal_base_url}/api/oauth/device/code", @@ -885,8 +882,12 @@ def _request_device_code( data = response.json() required_fields = [ - "device_code", "user_code", "verification_uri", - "verification_uri_complete", "expires_in", "interval", + "device_code", + "user_code", + "verification_uri", + "verification_uri_complete", + "expires_in", + "interval", ] missing = [f for f in required_fields if f not in data] if missing: @@ -901,7 +902,7 @@ def _poll_for_token( device_code: str, expires_in: int, poll_interval: int, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Poll the token endpoint until the user approves or the code expires.""" deadline = time.time() + max(1, expires_in) current_interval = max(1, min(poll_interval, DEVICE_AUTH_POLL_INTERVAL_CAP_SECONDS)) @@ -947,13 +948,14 @@ def _poll_for_token( # Nous Portal — token refresh, agent key minting, model discovery # ============================================================================= + def _refresh_access_token( *, client: httpx.Client, portal_base_url: str, client_id: str, refresh_token: str, -) -> Dict[str, Any]: +) -> dict[str, Any]: response = client.post( f"{portal_base_url}/api/oauth/token", data={ @@ -966,15 +968,15 @@ def _refresh_access_token( if response.status_code == 200: payload = response.json() if "access_token" not in payload: - raise AuthError("Refresh response missing access_token", - provider="nous", code="invalid_token", relogin_required=True) + raise AuthError( + "Refresh response missing access_token", provider="nous", code="invalid_token", relogin_required=True + ) return payload try: error_payload = response.json() except Exception as exc: - raise AuthError("Refresh token exchange failed", - provider="nous", relogin_required=True) from exc + raise AuthError("Refresh token exchange failed", provider="nous", relogin_required=True) from exc code = str(error_payload.get("error", "invalid_grant")) description = str(error_payload.get("error_description") or "Refresh token exchange failed") @@ -988,7 +990,7 @@ def _mint_agent_key( portal_base_url: str, access_token: str, min_ttl_seconds: int, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Mint (or reuse) a short-lived inference API key.""" response = client.post( f"{portal_base_url}/api/oauth/agent-key", @@ -999,15 +1001,13 @@ def _mint_agent_key( if response.status_code == 200: payload = response.json() if "api_key" not in payload: - raise AuthError("Mint response missing api_key", - provider="nous", code="server_error") + raise AuthError("Mint response missing api_key", provider="nous", code="server_error") return payload try: error_payload = response.json() except Exception as exc: - raise AuthError("Agent key mint request failed", - provider="nous", code="server_error") from exc + raise AuthError("Agent key mint request failed", provider="nous", code="server_error") from exc code = str(error_payload.get("error", "server_error")) description = str(error_payload.get("error_description") or "Agent key mint request failed") @@ -1021,7 +1021,7 @@ def fetch_nous_models( api_key: str, timeout_seconds: float = 15.0, verify: bool | str = True, -) -> List[str]: +) -> list[str]: """Fetch available model IDs from the Nous inference API.""" timeout = httpx.Timeout(timeout_seconds) with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client: @@ -1044,7 +1044,7 @@ def fetch_nous_models( if not isinstance(data, list): return [] - model_ids: List[str] = [] + model_ids: list[str] = [] for item in data: if not isinstance(item, dict): continue @@ -1059,7 +1059,7 @@ def fetch_nous_models( return list(dict.fromkeys(model_ids)) -def _agent_key_is_usable(state: Dict[str, Any], min_ttl_seconds: int) -> bool: +def _agent_key_is_usable(state: dict[str, Any], min_ttl_seconds: int) -> bool: key = state.get("agent_key") if not isinstance(key, str) or not key.strip(): return False @@ -1070,10 +1070,10 @@ def resolve_nous_runtime_credentials( *, min_key_ttl_seconds: int = DEFAULT_AGENT_KEY_MIN_TTL_SECONDS, timeout_seconds: float = 15.0, - insecure: Optional[bool] = None, - ca_bundle: Optional[str] = None, + insecure: bool | None = None, + ca_bundle: str | None = None, force_mint: bool = False, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Resolve Nous inference credentials for runtime use. @@ -1092,8 +1092,7 @@ def resolve_nous_runtime_credentials( state = _load_provider_state(auth_store, "nous") if not state: - raise AuthError("Hermes is not logged into Nous Portal.", - provider="nous", relogin_required=True) + raise AuthError("Hermes is not logged into Nous Portal.", provider="nous", relogin_required=True) portal_base_url = ( _optional_base_url(state.get("portal_base_url")) @@ -1143,14 +1142,14 @@ def resolve_nous_runtime_credentials( refresh_token = state.get("refresh_token") if not isinstance(access_token, str) or not access_token: - raise AuthError("No access token found for Nous Portal login.", - provider="nous", relogin_required=True) + raise AuthError("No access token found for Nous Portal login.", provider="nous", relogin_required=True) # Step 1: refresh access token if expiring if _is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS): if not isinstance(refresh_token, str) or not refresh_token: - raise AuthError("Session expired and no refresh token is available.", - provider="nous", relogin_required=True) + raise AuthError( + "Session expired and no refresh token is available.", provider="nous", relogin_required=True + ) _oauth_trace( "refresh_start", @@ -1159,10 +1158,12 @@ def resolve_nous_runtime_credentials( refresh_token_fp=_token_fingerprint(refresh_token), ) refreshed = _refresh_access_token( - client=client, portal_base_url=portal_base_url, - client_id=client_id, refresh_token=refresh_token, + client=client, + portal_base_url=portal_base_url, + client_id=client_id, + refresh_token=refresh_token, ) - now = datetime.now(timezone.utc) + now = datetime.now(UTC) access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in")) previous_refresh_token = refresh_token state["access_token"] = refreshed["access_token"] @@ -1174,9 +1175,7 @@ def resolve_nous_runtime_credentials( inference_base_url = refreshed_url state["obtained_at"] = now.isoformat() state["expires_in"] = access_ttl - state["expires_at"] = datetime.fromtimestamp( - now.timestamp() + access_ttl, tz=timezone.utc - ).isoformat() + state["expires_at"] = datetime.fromtimestamp(now.timestamp() + access_ttl, tz=UTC).isoformat() access_token = state["access_token"] refresh_token = state["refresh_token"] _oauth_trace( @@ -1191,7 +1190,7 @@ def resolve_nous_runtime_credentials( # Step 2: mint agent key if missing/expiring used_cached_key = False - mint_payload: Optional[Dict[str, Any]] = None + mint_payload: dict[str, Any] | None = None if not force_mint and _agent_key_is_usable(state, min_key_ttl_seconds): used_cached_key = True @@ -1204,8 +1203,10 @@ def resolve_nous_runtime_credentials( access_token_fp=_token_fingerprint(access_token), ) mint_payload = _mint_agent_key( - client=client, portal_base_url=portal_base_url, - access_token=access_token, min_ttl_seconds=min_key_ttl_seconds, + client=client, + portal_base_url=portal_base_url, + access_token=access_token, + min_ttl_seconds=min_key_ttl_seconds, ) except AuthError as exc: _oauth_trace( @@ -1227,10 +1228,12 @@ def resolve_nous_runtime_credentials( refresh_token_fp=_token_fingerprint(latest_refresh_token), ) refreshed = _refresh_access_token( - client=client, portal_base_url=portal_base_url, - client_id=client_id, refresh_token=latest_refresh_token, + client=client, + portal_base_url=portal_base_url, + client_id=client_id, + refresh_token=latest_refresh_token, ) - now = datetime.now(timezone.utc) + now = datetime.now(UTC) access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in")) state["access_token"] = refreshed["access_token"] state["refresh_token"] = refreshed.get("refresh_token") or latest_refresh_token @@ -1241,9 +1244,7 @@ def resolve_nous_runtime_credentials( inference_base_url = refreshed_url state["obtained_at"] = now.isoformat() state["expires_in"] = access_ttl - state["expires_at"] = datetime.fromtimestamp( - now.timestamp() + access_ttl, tz=timezone.utc - ).isoformat() + state["expires_at"] = datetime.fromtimestamp(now.timestamp() + access_ttl, tz=UTC).isoformat() access_token = state["access_token"] refresh_token = state["refresh_token"] _oauth_trace( @@ -1257,14 +1258,16 @@ def resolve_nous_runtime_credentials( _persist_state("post_refresh_mint_retry") mint_payload = _mint_agent_key( - client=client, portal_base_url=portal_base_url, - access_token=access_token, min_ttl_seconds=min_key_ttl_seconds, + client=client, + portal_base_url=portal_base_url, + access_token=access_token, + min_ttl_seconds=min_key_ttl_seconds, ) else: raise if mint_payload is not None: - now = datetime.now(timezone.utc) + now = datetime.now(UTC) state["agent_key"] = mint_payload.get("api_key") state["agent_key_id"] = mint_payload.get("key_id") state["agent_key_expires_at"] = mint_payload.get("expires_at") @@ -1293,8 +1296,7 @@ def resolve_nous_runtime_credentials( api_key = state.get("agent_key") if not isinstance(api_key, str) or not api_key: - raise AuthError("Failed to resolve a Nous inference API key", - provider="nous", code="server_error") + raise AuthError("Failed to resolve a Nous inference API key", provider="nous", code="server_error") expires_at = state.get("agent_key_expires_at") expires_epoch = _parse_iso_timestamp(expires_at) @@ -1319,7 +1321,8 @@ def resolve_nous_runtime_credentials( # Status helpers # ============================================================================= -def get_nous_auth_status() -> Dict[str, Any]: + +def get_nous_auth_status() -> dict[str, Any]: """Status snapshot for `hermes status` output.""" state = get_provider_auth_state("nous") if not state: @@ -1341,7 +1344,7 @@ def get_nous_auth_status() -> Dict[str, Any]: } -def get_codex_auth_status() -> Dict[str, Any]: +def get_codex_auth_status() -> dict[str, Any]: """Status snapshot for Codex auth.""" try: creds = resolve_codex_runtime_credentials() @@ -1360,7 +1363,7 @@ def get_codex_auth_status() -> Dict[str, Any]: } -def get_api_key_provider_status(provider_id: str) -> Dict[str, Any]: +def get_api_key_provider_status(provider_id: str) -> dict[str, Any]: """Status snapshot for API-key providers (z.ai, Kimi, MiniMax).""" pconfig = PROVIDER_REGISTRY.get(provider_id) if not pconfig or pconfig.auth_type != "api_key": @@ -1396,7 +1399,7 @@ def get_api_key_provider_status(provider_id: str) -> Dict[str, Any]: } -def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]: +def get_auth_status(provider_id: str | None = None) -> dict[str, Any]: """Generic auth status dispatcher.""" target = provider_id or get_active_provider() if target == "nous": @@ -1410,7 +1413,7 @@ def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]: return {"logged_in": False} -def resolve_api_key_provider_credentials(provider_id: str) -> Dict[str, Any]: +def resolve_api_key_provider_credentials(provider_id: str) -> dict[str, Any]: """Resolve API key and base URL for an API-key provider. Returns dict with: provider, api_key, base_url, source. @@ -1455,7 +1458,8 @@ def resolve_api_key_provider_credentials(provider_id: str) -> Dict[str, Any]: # External credential detection # ============================================================================= -def detect_external_credentials() -> List[Dict[str, Any]]: + +def detect_external_credentials() -> list[dict[str, Any]]: """Scan for credentials from other CLI tools that Hermes can reuse. Returns a list of dicts, each with: @@ -1463,17 +1467,19 @@ def detect_external_credentials() -> List[Dict[str, Any]]: - path: str -- filesystem path where creds were found - label: str -- human-friendly description for the setup UI """ - found: List[Dict[str, Any]] = [] + found: list[dict[str, Any]] = [] # Codex CLI: ~/.codex/auth.json (importable, not shared) cli_tokens = _import_codex_cli_tokens() if cli_tokens: codex_path = Path.home() / ".codex" / "auth.json" - found.append({ - "provider": "openai-codex", - "path": str(codex_path), - "label": f"Codex CLI credentials found ({codex_path}) — run `hermes login` to create a separate session", - }) + found.append( + { + "provider": "openai-codex", + "path": str(codex_path), + "label": f"Codex CLI credentials found ({codex_path}) — run `hermes login` to create a separate session", + } + ) return found @@ -1482,6 +1488,7 @@ def detect_external_credentials() -> List[Dict[str, Any]]: # CLI Commands — login / logout # ============================================================================= + def _update_config_for_provider(provider_id: str, inference_base_url: str) -> Path: """Update config.yaml and auth.json to reflect the active provider.""" # Set active_provider in auth.json so auto-resolution picks this provider @@ -1494,7 +1501,7 @@ def _update_config_for_provider(provider_id: str, inference_base_url: str) -> Pa config_path = get_config_path() config_path.parent.mkdir(parents=True, exist_ok=True) - config: Dict[str, Any] = {} + config: dict[str, Any] = {} if config_path.exists(): try: loaded = yaml.safe_load(config_path.read_text()) or {} @@ -1542,7 +1549,7 @@ def _reset_config_provider() -> Path: return config_path -def _prompt_model_selection(model_ids: List[str], current_model: str = "") -> Optional[str]: +def _prompt_model_selection(model_ids: list[str], current_model: str = "") -> str | None: """Interactive model selection. Puts current_model first with a marker. Returns chosen model ID or None.""" # Reorder: current model first, then the rest (deduplicated) ordered = [] @@ -1564,6 +1571,7 @@ def _prompt_model_selection(model_ids: List[str], current_model: str = "") -> Op # Try arrow-key menu first, fall back to number input try: from simple_term_menu import TerminalMenu + choices = [f" {_label(mid)}" for mid in ordered] choices.append(" Enter custom model name") choices.append(" Skip (keep current)") @@ -1621,7 +1629,7 @@ def _prompt_model_selection(model_ids: List[str], current_model: str = "") -> Op def _save_model_choice(model_id: str) -> None: """Save the selected model to config.yaml and .env.""" - from hermes_cli.config import save_config, load_config, save_env_value + from hermes_cli.config import load_config, save_config, save_env_value config = load_config() # Handle both string and dict model formats @@ -1693,11 +1701,11 @@ def _login_openai_codex(args, pconfig: ProviderConfig) -> None: config_path = _update_config_for_provider("openai-codex", creds.get("base_url", DEFAULT_CODEX_BASE_URL)) print() print("Login successful!") - print(f" Auth state: ~/.hermes/auth.json") + print(" Auth state: ~/.hermes/auth.json") print(f" Config updated: {config_path} (model.provider=openai-codex)") -def _codex_device_code_login() -> Dict[str, Any]: +def _codex_device_code_login() -> dict[str, Any]: """Run the OpenAI device code login flow and return credentials dict.""" import time as _time @@ -1715,13 +1723,15 @@ def _codex_device_code_login() -> Dict[str, Any]: except Exception as exc: raise AuthError( f"Failed to request device code: {exc}", - provider="openai-codex", code="device_code_request_failed", + provider="openai-codex", + code="device_code_request_failed", ) if resp.status_code != 200: raise AuthError( f"Device code request returned status {resp.status_code}.", - provider="openai-codex", code="device_code_request_error", + provider="openai-codex", + code="device_code_request_error", ) device_data = resp.json() @@ -1732,14 +1742,15 @@ def _codex_device_code_login() -> Dict[str, Any]: if not user_code or not device_auth_id: raise AuthError( "Device code response missing required fields.", - provider="openai-codex", code="device_code_incomplete", + provider="openai-codex", + code="device_code_incomplete", ) # Step 2: Show user the code print("To continue, follow these steps:\n") - print(f" 1. Open this URL in your browser:") + print(" 1. Open this URL in your browser:") print(f" \033[94m{issuer}/codex/device\033[0m\n") - print(f" 2. Enter this code:") + print(" 2. Enter this code:") print(f" \033[94m{user_code}\033[0m\n") print("Waiting for sign-in... (press Ctrl+C to cancel)") @@ -1766,7 +1777,8 @@ def _codex_device_code_login() -> Dict[str, Any]: else: raise AuthError( f"Device auth polling returned status {poll_resp.status_code}.", - provider="openai-codex", code="device_code_poll_error", + provider="openai-codex", + code="device_code_poll_error", ) except KeyboardInterrupt: print("\nLogin cancelled.") @@ -1775,7 +1787,8 @@ def _codex_device_code_login() -> Dict[str, Any]: if code_resp is None: raise AuthError( "Login timed out after 15 minutes.", - provider="openai-codex", code="device_code_timeout", + provider="openai-codex", + code="device_code_timeout", ) # Step 4: Exchange authorization code for tokens @@ -1786,7 +1799,8 @@ def _codex_device_code_login() -> Dict[str, Any]: if not authorization_code or not code_verifier: raise AuthError( "Device auth response missing authorization_code or code_verifier.", - provider="openai-codex", code="device_code_incomplete_exchange", + provider="openai-codex", + code="device_code_incomplete_exchange", ) try: @@ -1805,13 +1819,15 @@ def _codex_device_code_login() -> Dict[str, Any]: except Exception as exc: raise AuthError( f"Token exchange failed: {exc}", - provider="openai-codex", code="token_exchange_failed", + provider="openai-codex", + code="token_exchange_failed", ) if token_resp.status_code != 200: raise AuthError( f"Token exchange returned status {token_resp.status_code}.", - provider="openai-codex", code="token_exchange_error", + provider="openai-codex", + code="token_exchange_error", ) tokens = token_resp.json() @@ -1821,14 +1837,12 @@ def _codex_device_code_login() -> Dict[str, Any]: if not access_token: raise AuthError( "Token exchange did not return an access_token.", - provider="openai-codex", code="token_exchange_no_access_token", + provider="openai-codex", + code="token_exchange_no_access_token", ) # Return tokens for the caller to persist (no longer writes to ~/.codex/) - base_url = ( - os.getenv("HERMES_CODEX_BASE_URL", "").strip().rstrip("/") - or DEFAULT_CODEX_BASE_URL - ) + base_url = os.getenv("HERMES_CODEX_BASE_URL", "").strip().rstrip("/") or DEFAULT_CODEX_BASE_URL return { "tokens": { @@ -1836,7 +1850,7 @@ def _codex_device_code_login() -> Dict[str, Any]: "refresh_token": refresh_token, }, "base_url": base_url, - "last_refresh": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), + "last_refresh": datetime.now(UTC).isoformat().replace("+00:00", "Z"), "auth_mode": "chatgpt", "source": "device-code", } @@ -1851,9 +1865,7 @@ def _login_nous(args, pconfig: ProviderConfig) -> None: or pconfig.portal_base_url ).rstrip("/") requested_inference_url = ( - getattr(args, "inference_url", None) - or os.getenv("NOUS_INFERENCE_BASE_URL") - or pconfig.inference_base_url + getattr(args, "inference_url", None) or os.getenv("NOUS_INFERENCE_BASE_URL") or pconfig.inference_base_url ).rstrip("/") client_id = getattr(args, "client_id", None) or pconfig.client_id scope = getattr(args, "scope", None) or pconfig.scope @@ -1862,11 +1874,7 @@ def _login_nous(args, pconfig: ProviderConfig) -> None: timeout = httpx.Timeout(timeout_seconds) insecure = bool(getattr(args, "insecure", False)) - ca_bundle = ( - getattr(args, "ca_bundle", None) - or os.getenv("HERMES_CA_BUNDLE") - or os.getenv("SSL_CERT_FILE") - ) + ca_bundle = getattr(args, "ca_bundle", None) or os.getenv("HERMES_CA_BUNDLE") or os.getenv("SSL_CERT_FILE") verify: bool | str = False if insecure else (ca_bundle if ca_bundle else True) # Skip browser open in SSH sessions @@ -1883,8 +1891,10 @@ def _login_nous(args, pconfig: ProviderConfig) -> None: try: with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client: device_data = _request_device_code( - client=client, portal_base_url=portal_base_url, - client_id=client_id, scope=scope, + client=client, + portal_base_url=portal_base_url, + client_id=client_id, + scope=scope, ) verification_url = str(device_data["verification_uri_complete"]) @@ -1908,19 +1918,19 @@ def _login_nous(args, pconfig: ProviderConfig) -> None: print(f"Waiting for approval (polling every {effective_interval}s)...") token_data = _poll_for_token( - client=client, portal_base_url=portal_base_url, - client_id=client_id, device_code=str(device_data["device_code"]), - expires_in=expires_in, poll_interval=interval, + client=client, + portal_base_url=portal_base_url, + client_id=client_id, + device_code=str(device_data["device_code"]), + expires_in=expires_in, + poll_interval=interval, ) # Process token response - now = datetime.now(timezone.utc) + now = datetime.now(UTC) token_expires_in = _coerce_ttl_seconds(token_data.get("expires_in", 0)) expires_at = now.timestamp() + token_expires_in - inference_base_url = ( - _optional_base_url(token_data.get("inference_base_url")) - or requested_inference_url - ) + inference_base_url = _optional_base_url(token_data.get("inference_base_url")) or requested_inference_url if inference_base_url != requested_inference_url: print(f"Using portal-provided inference URL: {inference_base_url}") @@ -1933,7 +1943,7 @@ def _login_nous(args, pconfig: ProviderConfig) -> None: "access_token": token_data["access_token"], "refresh_token": token_data.get("refresh_token"), "obtained_at": now.isoformat(), - "expires_at": datetime.fromtimestamp(expires_at, tz=timezone.utc).isoformat(), + "expires_at": datetime.fromtimestamp(expires_at, tz=UTC).isoformat(), "expires_in": token_expires_in, "tls": { "insecure": verify is False, @@ -1964,13 +1974,13 @@ def _login_nous(args, pconfig: ProviderConfig) -> None: runtime_creds = resolve_nous_runtime_credentials( min_key_ttl_seconds=5 * 60, timeout_seconds=timeout_seconds, - insecure=insecure, ca_bundle=ca_bundle, + insecure=insecure, + ca_bundle=ca_bundle, ) runtime_key = runtime_creds.get("api_key") runtime_base_url = runtime_creds.get("base_url") or inference_base_url if not isinstance(runtime_key, str) or not runtime_key: - raise AuthError("No runtime API key available to fetch models", - provider="nous", code="invalid_token") + raise AuthError("No runtime API key available to fetch models", provider="nous", code="invalid_token") model_ids = fetch_nous_models( inference_base_url=runtime_base_url, diff --git a/hermes_cli/banner.py b/hermes_cli/banner.py index 395a2381fb..6d8b9e5db7 100644 --- a/hermes_cli/banner.py +++ b/hermes_cli/banner.py @@ -9,14 +9,12 @@ import os import subprocess import time from pathlib import Path -from typing import Dict, List, Any, Optional - -from rich.console import Console -from rich.panel import Panel -from rich.table import Table from prompt_toolkit import print_formatted_text as _pt_print from prompt_toolkit.formatted_text import ANSI as _PT_ANSI +from rich.console import Console +from rich.panel import Panel +from rich.table import Table logger = logging.getLogger(__name__) @@ -77,7 +75,8 @@ COMPACT_BANNER = """ # Skills scanning # ========================================================================= -def get_available_skills() -> Dict[str, List[str]]: + +def get_available_skills() -> dict[str, list[str]]: """Scan ~/.hermes/skills/ and return skills grouped by category.""" import os @@ -110,7 +109,7 @@ def get_available_skills() -> Dict[str, List[str]]: _UPDATE_CHECK_CACHE_SECONDS = 6 * 3600 -def check_for_updates() -> Optional[int]: +def check_for_updates() -> int | None: """Check how many commits behind origin/main the local repo is. Does a ``git fetch`` at most once every 6 hours (cached to @@ -139,7 +138,8 @@ def check_for_updates() -> Optional[int]: try: subprocess.run( ["git", "fetch", "origin", "--quiet"], - capture_output=True, timeout=10, + capture_output=True, + timeout=10, cwd=str(repo_dir), ) except Exception: @@ -149,7 +149,9 @@ def check_for_updates() -> Optional[int]: try: result = subprocess.run( ["git", "rev-list", "--count", "HEAD..origin/main"], - capture_output=True, text=True, timeout=5, + capture_output=True, + text=True, + timeout=5, cwd=str(repo_dir), ) if result.returncode == 0: @@ -172,6 +174,7 @@ def check_for_updates() -> Optional[int]: # Welcome banner # ========================================================================= + def _format_context_length(tokens: int) -> str: """Format a token count for display (e.g. 128000 → '128K', 1048576 → '1M').""" if tokens >= 1_000_000: @@ -183,12 +186,16 @@ def _format_context_length(tokens: int) -> str: return str(tokens) -def build_welcome_banner(console: Console, model: str, cwd: str, - tools: List[dict] = None, - enabled_toolsets: List[str] = None, - session_id: str = None, - get_toolset_for_tool=None, - context_length: int = None): +def build_welcome_banner( + console: Console, + model: str, + cwd: str, + tools: list[dict] = None, + enabled_toolsets: list[str] = None, + session_id: str = None, + get_toolset_for_tool=None, + context_length: int = None, +): """Build and print a welcome banner with caduceus on left and info on right. Args: @@ -201,7 +208,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str, get_toolset_for_tool: Callable to map tool name -> toolset name. context_length: Model's context window size in tokens. """ - from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS + from model_tools import check_tool_availability + if get_toolset_for_tool is None: from model_tools import get_toolset_for_tool @@ -221,7 +229,9 @@ def build_welcome_banner(console: Console, model: str, cwd: str, model_short = model.split("/")[-1] if "/" in model else model if len(model_short) > 28: model_short = model_short[:25] + "..." - ctx_str = f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else "" + ctx_str = ( + f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else "" + ) left_lines.append(f"[#FFBF00]{model_short}[/]{ctx_str} [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]") left_lines.append(f"[dim #B8860B]{cwd}[/]") if session_id: @@ -229,7 +239,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str, left_content = "\n".join(left_lines) right_lines = ["[bold #FFBF00]Available Tools[/]"] - toolsets_dict: Dict[str, list] = {} + toolsets_dict: dict[str, list] = {} for tool in tools: tool_name = tool["function"]["name"] @@ -286,6 +296,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str, # MCP Servers section (only if configured) try: from tools.mcp_tool import get_mcp_status + mcp_status = get_mcp_status() except Exception: mcp_status = [] @@ -300,10 +311,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str, f"[dim #B8860B]—[/] [#FFF8DC]{srv['tools']} tool(s)[/]" ) else: - right_lines.append( - f"[red]{srv['name']}[/] [dim]({srv['transport']})[/] " - f"[red]— failed[/]" - ) + right_lines.append(f"[red]{srv['name']}[/] [dim]({srv['transport']})[/] [red]— failed[/]") right_lines.append("") right_lines.append("[bold #FFBF00]Available Skills[/]") diff --git a/hermes_cli/callbacks.py b/hermes_cli/callbacks.py index bfce9c0010..c970b382ba 100644 --- a/hermes_cli/callbacks.py +++ b/hermes_cli/callbacks.py @@ -9,7 +9,7 @@ with the TUI. import queue import time as _time -from hermes_cli.banner import cprint, _DIM, _RST +from hermes_cli.banner import _DIM, _RST, cprint def clarify_callback(cli, question, choices): @@ -33,7 +33,7 @@ def clarify_callback(cli, question, choices): cli._clarify_deadline = _time.monotonic() + timeout cli._clarify_freetext = is_open_ended - if hasattr(cli, '_app') and cli._app: + if hasattr(cli, "_app") and cli._app: cli._app.invalidate() while True: @@ -45,13 +45,13 @@ def clarify_callback(cli, question, choices): remaining = cli._clarify_deadline - _time.monotonic() if remaining <= 0: break - if hasattr(cli, '_app') and cli._app: + if hasattr(cli, "_app") and cli._app: cli._app.invalidate() cli._clarify_state = None cli._clarify_freetext = False cli._clarify_deadline = 0 - if hasattr(cli, '_app') and cli._app: + if hasattr(cli, "_app") and cli._app: cli._app.invalidate() cprint(f"\n{_DIM}(clarify timed out after {timeout}s — agent will decide){_RST}") return ( @@ -71,7 +71,7 @@ def sudo_password_callback(cli) -> str: cli._sudo_state = {"response_queue": response_queue} cli._sudo_deadline = _time.monotonic() + timeout - if hasattr(cli, '_app') and cli._app: + if hasattr(cli, "_app") and cli._app: cli._app.invalidate() while True: @@ -79,7 +79,7 @@ def sudo_password_callback(cli) -> str: result = response_queue.get(timeout=1) cli._sudo_state = None cli._sudo_deadline = 0 - if hasattr(cli, '_app') and cli._app: + if hasattr(cli, "_app") and cli._app: cli._app.invalidate() if result: cprint(f"\n{_DIM} ✓ Password received (cached for session){_RST}") @@ -90,12 +90,12 @@ def sudo_password_callback(cli) -> str: remaining = cli._sudo_deadline - _time.monotonic() if remaining <= 0: break - if hasattr(cli, '_app') and cli._app: + if hasattr(cli, "_app") and cli._app: cli._app.invalidate() cli._sudo_state = None cli._sudo_deadline = 0 - if hasattr(cli, '_app') and cli._app: + if hasattr(cli, "_app") and cli._app: cli._app.invalidate() cprint(f"\n{_DIM} ⏱ Timeout — continuing without sudo{_RST}") return "" @@ -119,7 +119,7 @@ def approval_callback(cli, command: str, description: str) -> str: } cli._approval_deadline = _time.monotonic() + timeout - if hasattr(cli, '_app') and cli._app: + if hasattr(cli, "_app") and cli._app: cli._app.invalidate() while True: @@ -127,19 +127,19 @@ def approval_callback(cli, command: str, description: str) -> str: result = response_queue.get(timeout=1) cli._approval_state = None cli._approval_deadline = 0 - if hasattr(cli, '_app') and cli._app: + if hasattr(cli, "_app") and cli._app: cli._app.invalidate() return result except queue.Empty: remaining = cli._approval_deadline - _time.monotonic() if remaining <= 0: break - if hasattr(cli, '_app') and cli._app: + if hasattr(cli, "_app") and cli._app: cli._app.invalidate() cli._approval_state = None cli._approval_deadline = 0 - if hasattr(cli, '_app') and cli._app: + if hasattr(cli, "_app") and cli._app: cli._app.invalidate() cprint(f"\n{_DIM} ⏱ Timeout — denying command{_RST}") return "deny" diff --git a/hermes_cli/clipboard.py b/hermes_cli/clipboard.py index 6373cfc8b3..1cbd9288e2 100644 --- a/hermes_cli/clipboard.py +++ b/hermes_cli/clipboard.py @@ -51,6 +51,7 @@ def has_clipboard_image() -> bool: # ── macOS ──────────────────────────────────────────────────────────────── + def _macos_save(dest: Path) -> bool: """Try pngpaste first (fast, handles more formats), fall back to osascript.""" return _macos_pngpaste(dest) or _macos_osascript(dest) @@ -61,7 +62,9 @@ def _macos_has_image() -> bool: try: info = subprocess.run( ["osascript", "-e", "clipboard info"], - capture_output=True, text=True, timeout=3, + capture_output=True, + text=True, + timeout=3, ) return "«class PNGf»" in info.stdout or "«class TIFF»" in info.stdout except Exception: @@ -73,7 +76,8 @@ def _macos_pngpaste(dest: Path) -> bool: try: r = subprocess.run( ["pngpaste", str(dest)], - capture_output=True, timeout=3, + capture_output=True, + timeout=3, ) if r.returncode == 0 and dest.exists() and dest.stat().st_size > 0: return True @@ -91,19 +95,21 @@ def _macos_osascript(dest: Path) -> bool: # Extract as PNG script = ( - 'try\n' - ' set imgData to the clipboard as «class PNGf»\n' + "try\n" + " set imgData to the clipboard as «class PNGf»\n" f' set f to open for access POSIX file "{dest}" with write permission\n' - ' write imgData to f\n' - ' close access f\n' - 'on error\n' + " write imgData to f\n" + " close access f\n" + "on error\n" ' return "fail"\n' - 'end try\n' + "end try\n" ) try: r = subprocess.run( ["osascript", "-e", script], - capture_output=True, text=True, timeout=5, + capture_output=True, + text=True, + timeout=5, ) if r.returncode == 0 and "fail" not in r.stdout and dest.exists() and dest.stat().st_size > 0: return True @@ -114,13 +120,14 @@ def _macos_osascript(dest: Path) -> bool: # ── Linux ──────────────────────────────────────────────────────────────── + def _is_wsl() -> bool: """Detect if running inside WSL (1 or 2).""" global _wsl_detected if _wsl_detected is not None: return _wsl_detected try: - with open("/proc/version", "r") as f: + with open("/proc/version") as f: _wsl_detected = "microsoft" in f.read().lower() except Exception: _wsl_detected = False @@ -145,10 +152,7 @@ def _linux_save(dest: Path) -> bool: # PowerShell script: get clipboard image as base64-encoded PNG on stdout. # Using .NET System.Windows.Forms.Clipboard — always available on Windows. -_PS_CHECK_IMAGE = ( - "Add-Type -AssemblyName System.Windows.Forms;" - "[System.Windows.Forms.Clipboard]::ContainsImage()" -) +_PS_CHECK_IMAGE = "Add-Type -AssemblyName System.Windows.Forms;[System.Windows.Forms.Clipboard]::ContainsImage()" _PS_EXTRACT_IMAGE = ( "Add-Type -AssemblyName System.Windows.Forms;" @@ -165,9 +169,10 @@ def _wsl_has_image() -> bool: """Check if Windows clipboard has an image (via powershell.exe).""" try: r = subprocess.run( - ["powershell.exe", "-NoProfile", "-NonInteractive", "-Command", - _PS_CHECK_IMAGE], - capture_output=True, text=True, timeout=8, + ["powershell.exe", "-NoProfile", "-NonInteractive", "-Command", _PS_CHECK_IMAGE], + capture_output=True, + text=True, + timeout=8, ) return r.returncode == 0 and "True" in r.stdout except FileNotFoundError: @@ -181,9 +186,10 @@ def _wsl_save(dest: Path) -> bool: """Extract clipboard image via powershell.exe → base64 → decode to PNG.""" try: r = subprocess.run( - ["powershell.exe", "-NoProfile", "-NonInteractive", "-Command", - _PS_EXTRACT_IMAGE], - capture_output=True, text=True, timeout=15, + ["powershell.exe", "-NoProfile", "-NonInteractive", "-Command", _PS_EXTRACT_IMAGE], + capture_output=True, + text=True, + timeout=15, ) if r.returncode != 0: return False @@ -206,16 +212,17 @@ def _wsl_save(dest: Path) -> bool: # ── Wayland (wl-paste) ────────────────────────────────────────────────── + def _wayland_has_image() -> bool: """Check if Wayland clipboard has image content.""" try: r = subprocess.run( ["wl-paste", "--list-types"], - capture_output=True, text=True, timeout=3, - ) - return r.returncode == 0 and any( - t.startswith("image/") for t in r.stdout.splitlines() + capture_output=True, + text=True, + timeout=3, ) + return r.returncode == 0 and any(t.startswith("image/") for t in r.stdout.splitlines()) except FileNotFoundError: logger.debug("wl-paste not installed — Wayland clipboard unavailable") except Exception: @@ -229,7 +236,9 @@ def _wayland_save(dest: Path) -> bool: # Check available MIME types types_r = subprocess.run( ["wl-paste", "--list-types"], - capture_output=True, text=True, timeout=3, + capture_output=True, + text=True, + timeout=3, ) if types_r.returncode != 0: return False @@ -237,8 +246,7 @@ def _wayland_save(dest: Path) -> bool: # Prefer PNG, fall back to other image formats mime = None - for preferred in ("image/png", "image/jpeg", "image/bmp", - "image/gif", "image/webp"): + for preferred in ("image/png", "image/jpeg", "image/bmp", "image/gif", "image/webp"): if preferred in types: mime = preferred break @@ -250,7 +258,10 @@ def _wayland_save(dest: Path) -> bool: with open(dest, "wb") as f: subprocess.run( ["wl-paste", "--type", mime], - stdout=f, stderr=subprocess.DEVNULL, timeout=5, check=True, + stdout=f, + stderr=subprocess.DEVNULL, + timeout=5, + check=True, ) if not dest.exists() or dest.stat().st_size == 0: @@ -276,6 +287,7 @@ def _convert_to_png(path: Path) -> bool: # Try Pillow first (likely installed in the venv) try: from PIL import Image + img = Image.open(path) img.save(path, "PNG") return True @@ -290,7 +302,8 @@ def _convert_to_png(path: Path) -> bool: path.rename(tmp) r = subprocess.run( ["convert", str(tmp), "png:" + str(path)], - capture_output=True, timeout=5, + capture_output=True, + timeout=5, ) tmp.unlink(missing_ok=True) if r.returncode == 0 and path.exists() and path.stat().st_size > 0: @@ -310,12 +323,15 @@ def _convert_to_png(path: Path) -> bool: # ── X11 (xclip) ───────────────────────────────────────────────────────── + def _xclip_has_image() -> bool: """Check if X11 clipboard has image content.""" try: r = subprocess.run( ["xclip", "-selection", "clipboard", "-t", "TARGETS", "-o"], - capture_output=True, text=True, timeout=3, + capture_output=True, + text=True, + timeout=3, ) return r.returncode == 0 and "image/png" in r.stdout except FileNotFoundError: @@ -331,7 +347,9 @@ def _xclip_save(dest: Path) -> bool: try: targets = subprocess.run( ["xclip", "-selection", "clipboard", "-t", "TARGETS", "-o"], - capture_output=True, text=True, timeout=3, + capture_output=True, + text=True, + timeout=3, ) if "image/png" not in targets.stdout: return False @@ -346,7 +364,10 @@ def _xclip_save(dest: Path) -> bool: with open(dest, "wb") as f: subprocess.run( ["xclip", "-selection", "clipboard", "-t", "image/png", "-o"], - stdout=f, stderr=subprocess.DEVNULL, timeout=5, check=True, + stdout=f, + stderr=subprocess.DEVNULL, + timeout=5, + check=True, ) if dest.exists() and dest.stat().st_size > 0: return True diff --git a/hermes_cli/codex_models.py b/hermes_cli/codex_models.py index bc7e8525ea..a02f8595c1 100644 --- a/hermes_cli/codex_models.py +++ b/hermes_cli/codex_models.py @@ -4,14 +4,12 @@ from __future__ import annotations import json import logging -from pathlib import Path -from typing import List, Optional - import os +from pathlib import Path logger = logging.getLogger(__name__) -DEFAULT_CODEX_MODELS: List[str] = [ +DEFAULT_CODEX_MODELS: list[str] = [ "gpt-5.3-codex", "gpt-5.2-codex", "gpt-5.1-codex-max", @@ -19,10 +17,11 @@ DEFAULT_CODEX_MODELS: List[str] = [ ] -def _fetch_models_from_api(access_token: str) -> List[str]: +def _fetch_models_from_api(access_token: str) -> list[str]: """Fetch available models from the Codex API. Returns visible models sorted by priority.""" try: import httpx + resp = httpx.get( "https://chatgpt.com/backend-api/codex/models?client_version=1.0.0", headers={"Authorization": f"Bearer {access_token}"}, @@ -57,7 +56,7 @@ def _fetch_models_from_api(access_token: str) -> List[str]: return [slug for _, slug in sortable] -def _read_default_model(codex_home: Path) -> Optional[str]: +def _read_default_model(codex_home: Path) -> str | None: config_path = codex_home / "config.toml" if not config_path.exists(): return None @@ -75,7 +74,7 @@ def _read_default_model(codex_home: Path) -> Optional[str]: return None -def _read_cache_models(codex_home: Path) -> List[str]: +def _read_cache_models(codex_home: Path) -> list[str]: cache_path = codex_home / "models_cache.json" if not cache_path.exists(): return [] @@ -104,22 +103,22 @@ def _read_cache_models(codex_home: Path) -> List[str]: sortable.append((rank, slug)) sortable.sort(key=lambda item: (item[0], item[1])) - deduped: List[str] = [] + deduped: list[str] = [] for _, slug in sortable: if slug not in deduped: deduped.append(slug) return deduped -def get_codex_model_ids(access_token: Optional[str] = None) -> List[str]: +def get_codex_model_ids(access_token: str | None = None) -> list[str]: """Return available Codex model IDs, trying API first, then local sources. - + Resolution order: API (live, if token provided) > config.toml default > local cache > hardcoded defaults. """ codex_home_str = os.getenv("CODEX_HOME", "").strip() or str(Path.home() / ".codex") codex_home = Path(codex_home_str).expanduser() - ordered: List[str] = [] + ordered: list[str] = [] # Try live API if we have a token if access_token: diff --git a/hermes_cli/commands.py b/hermes_cli/commands.py index 20f01b1748..9bba2edd6a 100644 --- a/hermes_cli/commands.py +++ b/hermes_cli/commands.py @@ -12,7 +12,6 @@ from typing import Any from prompt_toolkit.completion import Completer, Completion - COMMANDS = { "/help": "Show this help message", "/tools": "List available tools", diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 018ac6557f..0fde6288a3 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -14,10 +14,10 @@ This module provides: import os import platform -import sys import subprocess +import sys from pathlib import Path -from typing import Dict, Any, Optional, List, Tuple +from typing import Any _IS_WINDOWS = platform.system() == "Windows" @@ -25,27 +25,31 @@ import yaml from hermes_cli.colors import Colors, color - # ============================================================================= # Config paths # ============================================================================= + def get_hermes_home() -> Path: """Get the Hermes home directory (~/.hermes).""" return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) + def get_config_path() -> Path: """Get the main config file path.""" return get_hermes_home() / "config.yaml" + def get_env_path() -> Path: """Get the .env file path (for API keys).""" return get_hermes_home() / ".env" + def get_project_root() -> Path: """Get the project installation directory.""" return Path(__file__).parent.parent.resolve() + def ensure_hermes_home(): """Ensure ~/.hermes directory structure exists.""" home = get_hermes_home() @@ -63,7 +67,6 @@ DEFAULT_CONFIG = { "model": "anthropic/claude-opus-4.6", "toolsets": ["hermes-cli"], "max_turns": 100, - "terminal": { "backend": "local", "cwd": ".", # Use current directory @@ -74,47 +77,42 @@ DEFAULT_CONFIG = { "daytona_image": "nikolaik/python-nodejs:python3.11-nodejs20", # Container resource limits (docker, singularity, modal, daytona — ignored for local/ssh) "container_cpu": 1, - "container_memory": 5120, # MB (default 5GB) - "container_disk": 51200, # MB (default 50GB) - "container_persistent": True, # Persist filesystem across sessions + "container_memory": 5120, # MB (default 5GB) + "container_disk": 51200, # MB (default 50GB) + "container_persistent": True, # Persist filesystem across sessions # Docker volume mounts — share host directories with the container. # Each entry is "host_path:container_path" (standard Docker -v syntax). # Example: ["/home/user/projects:/workspace/projects", "/data:/data"] "docker_volumes": [], }, - "browser": { "inactivity_timeout": 120, "record_sessions": False, # Auto-record browser sessions as WebM videos }, - "compression": { "enabled": True, "threshold": 0.85, "summary_model": "google/gemini-3-flash-preview", "summary_provider": "auto", }, - # Auxiliary model overrides (advanced). By default Hermes auto-selects # the provider and model for each side task. Set these to override. "auxiliary": { "vision": { - "provider": "auto", # auto | openrouter | nous | main - "model": "", # e.g. "google/gemini-2.5-flash", "gpt-4o" + "provider": "auto", # auto | openrouter | nous | main + "model": "", # e.g. "google/gemini-2.5-flash", "gpt-4o" }, "web_extract": { "provider": "auto", "model": "", }, }, - "display": { "compact": False, "personality": "kawaii", "resume_display": "full", # "full" (show previous messages) | "minimal" (one-liner only) "bell_on_complete": False, # Play terminal bell (\a) when agent finishes a response }, - # Text-to-speech configuration "tts": { "provider": "edge", # "edge" (free) | "elevenlabs" (premium) | "openai" @@ -132,43 +130,35 @@ DEFAULT_CONFIG = { # Voices: alloy, echo, fable, onyx, nova, shimmer }, }, - "stt": { "enabled": True, "model": "whisper-1", }, - "human_delay": { "mode": "off", "min_ms": 800, "max_ms": 2500, }, - # Persistent memory -- bounded curated memory injected into system prompt "memory": { "memory_enabled": True, "user_profile_enabled": True, - "memory_char_limit": 2200, # ~800 tokens at 2.75 chars/token - "user_char_limit": 1375, # ~500 tokens at 2.75 chars/token + "memory_char_limit": 2200, # ~800 tokens at 2.75 chars/token + "user_char_limit": 1375, # ~500 tokens at 2.75 chars/token }, - # Ephemeral prefill messages file — JSON list of {role, content} dicts # injected at the start of every API call for few-shot priming. # Never saved to sessions, logs, or trajectories. "prefill_messages_file": "", - # Honcho AI-native memory -- reads ~/.honcho/config.json as single source of truth. # This section is only needed for hermes-specific overrides; everything else # (apiKey, workspace, peerName, sessions, enabled) comes from the global config. "honcho": {}, - # IANA timezone (e.g. "Asia/Kolkata", "America/New_York"). # Empty string means use server-local time. "timezone": "", - # Permanently allowed dangerous command patterns (added via "always" approval) "command_allowlist": [], - # Config schema version - bump this when adding new required fields "_config_version": 5, } @@ -179,11 +169,17 @@ DEFAULT_CONFIG = { # Track which env vars were introduced in each config version. # Migration only mentions vars new since the user's previous version. -ENV_VARS_BY_VERSION: Dict[int, List[str]] = { +ENV_VARS_BY_VERSION: dict[int, list[str]] = { 3: ["FIRECRAWL_API_KEY", "BROWSERBASE_API_KEY", "BROWSERBASE_PROJECT_ID", "FAL_KEY"], 4: ["VOICE_TOOLS_OPENAI_KEY", "ELEVENLABS_API_KEY"], - 5: ["WHATSAPP_ENABLED", "WHATSAPP_MODE", "WHATSAPP_ALLOWED_USERS", - "SLACK_BOT_TOKEN", "SLACK_APP_TOKEN", "SLACK_ALLOWED_USERS"], + 5: [ + "WHATSAPP_ENABLED", + "WHATSAPP_MODE", + "WHATSAPP_ALLOWED_USERS", + "SLACK_BOT_TOKEN", + "SLACK_APP_TOKEN", + "SLACK_ALLOWED_USERS", + ], } # Required environment variables with metadata for migration prompts. @@ -284,7 +280,6 @@ OPTIONAL_ENV_VARS = { "category": "provider", "advanced": True, }, - # ── Tool API keys ── "FIRECRAWL_API_KEY": { "description": "Firecrawl API key for web search and scraping", @@ -364,7 +359,6 @@ OPTIONAL_ENV_VARS = { "password": True, "category": "tool", }, - # ── Honcho ── "HONCHO_API_KEY": { "description": "Honcho API key for AI-native persistent memory", @@ -374,7 +368,6 @@ OPTIONAL_ENV_VARS = { "password": True, "category": "tool", }, - # ── Messaging platforms ── "TELEGRAM_BOT_TOKEN": { "description": "Telegram bot token from @BotFather", @@ -406,8 +399,8 @@ OPTIONAL_ENV_VARS = { }, "SLACK_BOT_TOKEN": { "description": "Slack bot token (xoxb-). Get from OAuth & Permissions after installing your app. " - "Required scopes: chat:write, app_mentions:read, channels:history, groups:history, " - "im:history, im:read, im:write, users:read, files:write", + "Required scopes: chat:write, app_mentions:read, channels:history, groups:history, " + "im:history, im:read, im:write, users:read, files:write", "prompt": "Slack Bot Token (xoxb-...)", "url": "https://api.slack.com/apps", "password": True, @@ -415,8 +408,8 @@ OPTIONAL_ENV_VARS = { }, "SLACK_APP_TOKEN": { "description": "Slack app-level token (xapp-) for Socket Mode. Get from Basic Information → " - "App-Level Tokens. Also ensure Event Subscriptions include: message.im, " - "message.channels, message.groups, app_mention", + "App-Level Tokens. Also ensure Event Subscriptions include: message.im, " + "message.channels, message.groups, app_mention", "prompt": "Slack App Token (xapp-...)", "url": "https://api.slack.com/apps", "password": True, @@ -430,7 +423,6 @@ OPTIONAL_ENV_VARS = { "category": "messaging", "advanced": True, }, - # ── Agent settings ── "MESSAGING_CWD": { "description": "Working directory for terminal commands via messaging", @@ -487,25 +479,25 @@ OPTIONAL_ENV_VARS = { } -def get_missing_env_vars(required_only: bool = False) -> List[Dict[str, Any]]: +def get_missing_env_vars(required_only: bool = False) -> list[dict[str, Any]]: """ Check which environment variables are missing. - + Returns list of dicts with var info for missing variables. """ missing = [] - + # Check required vars for var_name, info in REQUIRED_ENV_VARS.items(): if not get_env_value(var_name): missing.append({"name": var_name, **info, "is_required": True}) - + # Check optional vars (if not required_only) if not required_only: for var_name, info in OPTIONAL_ENV_VARS.items(): if not get_env_value(var_name): missing.append({"name": var_name, **info, "is_required": False}) - + return missing @@ -524,10 +516,10 @@ def _set_nested(config: dict, dotted_key: str, value): current[parts[-1]] = value -def get_missing_config_fields() -> List[Dict[str, Any]]: +def get_missing_config_fields() -> list[dict[str, Any]]: """ Check which config fields are missing or outdated (recursive). - + Walks the DEFAULT_CONFIG tree at arbitrary depth and reports any keys present in defaults but absent from the user's loaded config. """ @@ -536,15 +528,17 @@ def get_missing_config_fields() -> List[Dict[str, Any]]: def _check(defaults: dict, current: dict, prefix: str = ""): for key, default_value in defaults.items(): - if key.startswith('_'): + if key.startswith("_"): continue full_key = key if not prefix else f"{prefix}.{key}" if key not in current: - missing.append({ - "key": full_key, - "default": default_value, - "description": f"New config option: {full_key}", - }) + missing.append( + { + "key": full_key, + "default": default_value, + "description": f"New config option: {full_key}", + } + ) elif isinstance(default_value, dict) and isinstance(current.get(key), dict): _check(default_value, current[key], full_key) @@ -552,10 +546,10 @@ def get_missing_config_fields() -> List[Dict[str, Any]]: return missing -def check_config_version() -> Tuple[int, int]: +def check_config_version() -> tuple[int, int]: """ Check config version. - + Returns (current_version, latest_version). """ config = load_config() @@ -564,22 +558,22 @@ def check_config_version() -> Tuple[int, int]: return current, latest -def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, Any]: +def migrate_config(interactive: bool = True, quiet: bool = False) -> dict[str, Any]: """ Migrate config to latest version, prompting for new required fields. - + Args: interactive: If True, prompt user for missing values quiet: If True, suppress output - + Returns: Dict with migration results: {"env_added": [...], "config_added": [...], "warnings": [...]} """ results = {"env_added": [], "config_added": [], "warnings": []} - + # Check config version current_ver, latest_ver = check_config_version() - + # ── Version 3 → 4: migrate tool progress from .env to config.yaml ── if current_ver < 4: config = load_config() @@ -594,7 +588,9 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A results["config_added"].append("display.tool_progress=off (from HERMES_TOOL_PROGRESS=false)") elif old_mode and old_mode.lower() in ("new", "all"): display["tool_progress"] = old_mode.lower() - results["config_added"].append(f"display.tool_progress={old_mode.lower()} (from HERMES_TOOL_PROGRESS_MODE)") + results["config_added"].append( + f"display.tool_progress={old_mode.lower()} (from HERMES_TOOL_PROGRESS_MODE)" + ) else: display["tool_progress"] = "all" results["config_added"].append("display.tool_progress=all (default)") @@ -602,7 +598,7 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A save_config(config) if not quiet: print(f" ✓ Migrated tool progress to config.yaml: {display['tool_progress']}") - + # ── Version 4 → 5: add timezone field ── if current_ver < 5: config = load_config() @@ -621,27 +617,28 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A if current_ver < latest_ver and not quiet: print(f"Config version: {current_ver} → {latest_ver}") - + # Check for missing required env vars missing_env = get_missing_env_vars(required_only=True) - + if missing_env and not quiet: print("\n⚠️ Missing required environment variables:") for var in missing_env: print(f" • {var['name']}: {var['description']}") - + if interactive and missing_env: print("\nLet's configure them now:\n") for var in missing_env: if var.get("url"): print(f" Get your key at: {var['url']}") - + if var.get("password"): import getpass + value = getpass.getpass(f" {var['prompt']}: ") else: value = input(f" {var['prompt']}: ").strip() - + if value: save_env_value(var["name"], value) results["env_added"].append(var["name"]) @@ -649,16 +646,13 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A else: results["warnings"].append(f"Skipped {var['name']} - some features may not work") print() - + # Check for missing optional env vars and offer to configure interactively # Skip "advanced" vars (like OPENAI_BASE_URL) -- those are for power users missing_optional = get_missing_env_vars(required_only=False) required_names = {v["name"] for v in missing_env} if missing_env else set() - missing_optional = [ - v for v in missing_optional - if v["name"] not in required_names and not v.get("advanced") - ] - + missing_optional = [v for v in missing_optional if v["name"] not in required_names and not v.get("advanced")] + # Only offer to configure env vars that are NEW since the user's previous version new_var_names = set() for ver in range(current_ver + 1, latest_ver + 1): @@ -690,6 +684,7 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A print(f" {info.get('description', name)}") if info.get("password"): import getpass + value = getpass.getpass(f" {info.get('prompt', name)} (Enter to skip): ") else: value = input(f" {info.get('prompt', name)} (Enter to skip): ").strip() @@ -700,22 +695,22 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A print() else: print(" Set later with: hermes config set KEY VALUE") - + # Check for missing config fields missing_config = get_missing_config_fields() - + if missing_config: config = load_config() - + for field in missing_config: key = field["key"] default = field["default"] - + _set_nested(config, key, default) results["config_added"].append(key) if not quiet: print(f" ✓ Added {key} = {default}") - + # Update version and save config["_config_version"] = latest_ver save_config(config) @@ -724,7 +719,7 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A config = load_config() config["_config_version"] = latest_ver save_config(config) - + return results @@ -737,33 +732,30 @@ def _deep_merge(base: dict, override: dict) -> dict: """ result = base.copy() for key, value in override.items(): - if ( - key in result - and isinstance(result[key], dict) - and isinstance(value, dict) - ): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): result[key] = _deep_merge(result[key], value) else: result[key] = value return result -def load_config() -> Dict[str, Any]: +def load_config() -> dict[str, Any]: """Load configuration from ~/.hermes/config.yaml.""" import copy + config_path = get_config_path() - + config = copy.deepcopy(DEFAULT_CONFIG) - + if config_path.exists(): try: with open(config_path) as f: user_config = yaml.safe_load(f) or {} - + config = _deep_merge(config, user_config) except Exception as e: print(f"Warning: Failed to load config: {e}") - + return config @@ -797,12 +789,12 @@ _COMMENTED_SECTIONS = """ """ -def save_config(config: Dict[str, Any]): +def save_config(config: dict[str, Any]): """Save configuration to ~/.hermes/config.yaml.""" ensure_hermes_home() config_path = get_config_path() - - with open(config_path, 'w') as f: + + with open(config_path, "w") as f: yaml.dump(config, f, default_flow_style=False, sort_keys=False) # Append commented-out sections for features that are off by default # or only relevant when explicitly configured. Skip sections the @@ -818,11 +810,11 @@ def save_config(config: Dict[str, Any]): f.write(_COMMENTED_SECTIONS) -def load_env() -> Dict[str, str]: +def load_env() -> dict[str, str]: """Load environment variables from ~/.hermes/.env.""" env_path = get_env_path() env_vars = {} - + if env_path.exists(): # On Windows, open() defaults to the system locale (cp1252) which can # fail on UTF-8 .env files. Use explicit UTF-8 only on Windows. @@ -830,10 +822,10 @@ def load_env() -> Dict[str, str]: with open(env_path, **open_kw) as f: for line in f: line = line.strip() - if line and not line.startswith('#') and '=' in line: - key, _, value = line.partition('=') - env_vars[key.strip()] = value.strip().strip('"\'') - + if line and not line.startswith("#") and "=" in line: + key, _, value = line.partition("=") + env_vars[key.strip()] = value.strip().strip("\"'") + return env_vars @@ -841,7 +833,7 @@ def save_env_value(key: str, value: str): """Save or update a value in ~/.hermes/.env.""" ensure_hermes_home() env_path = get_env_path() - + # On Windows, open() defaults to the system locale (cp1252) which can # cause OSError errno 22 on UTF-8 .env files. read_kw = {"encoding": "utf-8", "errors": "replace"} if _IS_WINDOWS else {} @@ -851,7 +843,7 @@ def save_env_value(key: str, value: str): if env_path.exists(): with open(env_path, **read_kw) as f: lines = f.readlines() - + # Find and update or append found = False for i, line in enumerate(lines): @@ -859,23 +851,23 @@ def save_env_value(key: str, value: str): lines[i] = f"{key}={value}\n" found = True break - + if not found: # Ensure there's a newline at the end of the file before appending if lines and not lines[-1].endswith("\n"): lines[-1] += "\n" lines.append(f"{key}={value}\n") - - with open(env_path, 'w', **write_kw) as f: + + with open(env_path, "w", **write_kw) as f: f.writelines(lines) -def get_env_value(key: str) -> Optional[str]: +def get_env_value(key: str) -> str | None: """Get a value from ~/.hermes/.env or environment.""" # Check environment first if key in os.environ: return os.environ[key] - + # Then check .env file env_vars = load_env() return env_vars.get(key) @@ -885,6 +877,7 @@ def get_env_value(key: str) -> Optional[str]: # Config display # ============================================================================= + def redact_key(key: str) -> str: """Redact an API key for display.""" if not key: @@ -898,23 +891,23 @@ def show_config(): """Display current configuration.""" config = load_config() env_vars = load_env() - + print() print(color("┌─────────────────────────────────────────────────────────┐", Colors.CYAN)) print(color("│ ⚕ Hermes Configuration │", Colors.CYAN)) print(color("└─────────────────────────────────────────────────────────┘", Colors.CYAN)) - + # Paths print() print(color("◆ Paths", Colors.CYAN, Colors.BOLD)) print(f" Config: {get_config_path()}") print(f" Secrets: {get_env_path()}") print(f" Install: {get_project_root()}") - + # API Keys print() print(color("◆ API Keys", Colors.CYAN, Colors.BOLD)) - + keys = [ ("OPENROUTER_API_KEY", "OpenRouter"), ("ANTHROPIC_API_KEY", "Anthropic"), @@ -923,48 +916,48 @@ def show_config(): ("BROWSERBASE_API_KEY", "Browserbase"), ("FAL_KEY", "FAL"), ] - + for env_key, name in keys: value = get_env_value(env_key) print(f" {name:<14} {redact_key(value)}") - + # Model settings print() print(color("◆ Model", Colors.CYAN, Colors.BOLD)) print(f" Model: {config.get('model', 'not set')}") print(f" Max turns: {config.get('max_turns', 100)}") print(f" Toolsets: {', '.join(config.get('toolsets', ['all']))}") - + # Terminal print() print(color("◆ Terminal", Colors.CYAN, Colors.BOLD)) - terminal = config.get('terminal', {}) + terminal = config.get("terminal", {}) print(f" Backend: {terminal.get('backend', 'local')}") print(f" Working dir: {terminal.get('cwd', '.')}") print(f" Timeout: {terminal.get('timeout', 60)}s") - - if terminal.get('backend') == 'docker': + + if terminal.get("backend") == "docker": print(f" Docker image: {terminal.get('docker_image', 'python:3.11-slim')}") - elif terminal.get('backend') == 'singularity': + elif terminal.get("backend") == "singularity": print(f" Image: {terminal.get('singularity_image', 'docker://python:3.11')}") - elif terminal.get('backend') == 'modal': + elif terminal.get("backend") == "modal": print(f" Modal image: {terminal.get('modal_image', 'python:3.11')}") - modal_token = get_env_value('MODAL_TOKEN_ID') + modal_token = get_env_value("MODAL_TOKEN_ID") print(f" Modal token: {'configured' if modal_token else '(not set)'}") - elif terminal.get('backend') == 'daytona': + elif terminal.get("backend") == "daytona": print(f" Daytona image: {terminal.get('daytona_image', 'nikolaik/python-nodejs:python3.11-nodejs20')}") - daytona_key = get_env_value('DAYTONA_API_KEY') + daytona_key = get_env_value("DAYTONA_API_KEY") print(f" API key: {'configured' if daytona_key else '(not set)'}") - elif terminal.get('backend') == 'ssh': - ssh_host = get_env_value('TERMINAL_SSH_HOST') - ssh_user = get_env_value('TERMINAL_SSH_USER') + elif terminal.get("backend") == "ssh": + ssh_host = get_env_value("TERMINAL_SSH_HOST") + ssh_user = get_env_value("TERMINAL_SSH_USER") print(f" SSH host: {ssh_host or '(not set)'}") print(f" SSH user: {ssh_user or '(not set)'}") - + # Timezone print() print(color("◆ Timezone", Colors.CYAN, Colors.BOLD)) - tz = config.get('timezone', '') + tz = config.get("timezone", "") if tz: print(f" Timezone: {tz}") else: @@ -973,48 +966,45 @@ def show_config(): # Compression print() print(color("◆ Context Compression", Colors.CYAN, Colors.BOLD)) - compression = config.get('compression', {}) - enabled = compression.get('enabled', True) + compression = config.get("compression", {}) + enabled = compression.get("enabled", True) print(f" Enabled: {'yes' if enabled else 'no'}") if enabled: print(f" Threshold: {compression.get('threshold', 0.85) * 100:.0f}%") print(f" Model: {compression.get('summary_model', 'google/gemini-3-flash-preview')}") - comp_provider = compression.get('summary_provider', 'auto') - if comp_provider != 'auto': + comp_provider = compression.get("summary_provider", "auto") + if comp_provider != "auto": print(f" Provider: {comp_provider}") - + # Auxiliary models - auxiliary = config.get('auxiliary', {}) + auxiliary = config.get("auxiliary", {}) aux_tasks = { - "Vision": auxiliary.get('vision', {}), - "Web extract": auxiliary.get('web_extract', {}), + "Vision": auxiliary.get("vision", {}), + "Web extract": auxiliary.get("web_extract", {}), } - has_overrides = any( - t.get('provider', 'auto') != 'auto' or t.get('model', '') - for t in aux_tasks.values() - ) + has_overrides = any(t.get("provider", "auto") != "auto" or t.get("model", "") for t in aux_tasks.values()) if has_overrides: print() print(color("◆ Auxiliary Models (overrides)", Colors.CYAN, Colors.BOLD)) for label, task_cfg in aux_tasks.items(): - prov = task_cfg.get('provider', 'auto') - mdl = task_cfg.get('model', '') - if prov != 'auto' or mdl: + prov = task_cfg.get("provider", "auto") + mdl = task_cfg.get("model", "") + if prov != "auto" or mdl: parts = [f"provider={prov}"] if mdl: parts.append(f"model={mdl}") print(f" {label:12s} {', '.join(parts)}") - + # Messaging print() print(color("◆ Messaging Platforms", Colors.CYAN, Colors.BOLD)) - - telegram_token = get_env_value('TELEGRAM_BOT_TOKEN') - discord_token = get_env_value('DISCORD_BOT_TOKEN') - + + telegram_token = get_env_value("TELEGRAM_BOT_TOKEN") + discord_token = get_env_value("DISCORD_BOT_TOKEN") + print(f" Telegram: {'configured' if telegram_token else color('not configured', Colors.DIM)}") print(f" Discord: {'configured' if discord_token else color('not configured', Colors.DIM)}") - + print() print(color("─" * 60, Colors.DIM)) print(color(" hermes config edit # Edit config file", Colors.DIM)) @@ -1026,28 +1016,29 @@ def show_config(): def edit_config(): """Open config file in user's editor.""" config_path = get_config_path() - + # Ensure config exists if not config_path.exists(): save_config(DEFAULT_CONFIG) print(f"Created {config_path}") - + # Find editor - editor = os.getenv('EDITOR') or os.getenv('VISUAL') - + editor = os.getenv("EDITOR") or os.getenv("VISUAL") + if not editor: # Try common editors - for cmd in ['nano', 'vim', 'vi', 'code', 'notepad']: + for cmd in ["nano", "vim", "vi", "code", "notepad"]: import shutil + if shutil.which(cmd): editor = cmd break - + if not editor: - print(f"No editor found. Config file is at:") + print("No editor found. Config file is at:") print(f" {config_path}") return - + print(f"Opening {config_path} in {editor}...") subprocess.run([editor, str(config_path)]) @@ -1056,20 +1047,39 @@ def set_config_value(key: str, value: str): """Set a configuration value.""" # Check if it's an API key (goes to .env) api_keys = [ - 'OPENROUTER_API_KEY', 'OPENAI_API_KEY', 'ANTHROPIC_API_KEY', 'VOICE_TOOLS_OPENAI_KEY', - 'FIRECRAWL_API_KEY', 'FIRECRAWL_API_URL', 'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID', - 'FAL_KEY', 'TELEGRAM_BOT_TOKEN', 'DISCORD_BOT_TOKEN', - 'TERMINAL_SSH_HOST', 'TERMINAL_SSH_USER', 'TERMINAL_SSH_KEY', - 'SUDO_PASSWORD', 'SLACK_BOT_TOKEN', 'SLACK_APP_TOKEN', - 'GITHUB_TOKEN', 'HONCHO_API_KEY', 'WANDB_API_KEY', - 'TINKER_API_KEY', + "OPENROUTER_API_KEY", + "OPENAI_API_KEY", + "ANTHROPIC_API_KEY", + "VOICE_TOOLS_OPENAI_KEY", + "FIRECRAWL_API_KEY", + "FIRECRAWL_API_URL", + "BROWSERBASE_API_KEY", + "BROWSERBASE_PROJECT_ID", + "FAL_KEY", + "TELEGRAM_BOT_TOKEN", + "DISCORD_BOT_TOKEN", + "TERMINAL_SSH_HOST", + "TERMINAL_SSH_USER", + "TERMINAL_SSH_KEY", + "SUDO_PASSWORD", + "SLACK_BOT_TOKEN", + "SLACK_APP_TOKEN", + "GITHUB_TOKEN", + "HONCHO_API_KEY", + "WANDB_API_KEY", + "TINKER_API_KEY", ] - - if key.upper() in api_keys or key.upper().endswith('_API_KEY') or key.upper().endswith('_TOKEN') or key.upper().startswith('TERMINAL_SSH'): + + if ( + key.upper() in api_keys + or key.upper().endswith("_API_KEY") + or key.upper().endswith("_TOKEN") + or key.upper().startswith("TERMINAL_SSH") + ): save_env_value(key.upper(), value) print(f"✓ Set {key} in {get_env_path()}") return - + # Otherwise it goes to config.yaml # Read the raw user config (not merged with defaults) to avoid # dumping all default values back to the file @@ -1081,33 +1091,33 @@ def set_config_value(key: str, value: str): user_config = yaml.safe_load(f) or {} except Exception: user_config = {} - + # Handle nested keys (e.g., "tts.provider") - parts = key.split('.') + parts = key.split(".") current = user_config - + for part in parts[:-1]: if part not in current or not isinstance(current.get(part), dict): current[part] = {} current = current[part] - + # Convert value to appropriate type - if value.lower() in ('true', 'yes', 'on'): + if value.lower() in ("true", "yes", "on"): value = True - elif value.lower() in ('false', 'no', 'off'): + elif value.lower() in ("false", "no", "off"): value = False elif value.isdigit(): value = int(value) - elif value.replace('.', '', 1).isdigit(): + elif value.replace(".", "", 1).isdigit(): value = float(value) - + current[parts[-1]] = value - + # Write only user config back (not the full merged defaults) ensure_hermes_home() - with open(config_path, 'w') as f: + with open(config_path, "w") as f: yaml.dump(user_config, f, default_flow_style=False, sort_keys=False) - + # Keep .env in sync for keys that terminal_tool reads directly from env vars. # config.yaml is authoritative, but terminal_tool only reads TERMINAL_ENV etc. _config_to_env_sync = { @@ -1130,19 +1140,20 @@ def set_config_value(key: str, value: str): # Command handler # ============================================================================= + def config_command(args): """Handle config subcommands.""" - subcmd = getattr(args, 'config_command', None) - + subcmd = getattr(args, "config_command", None) + if subcmd is None or subcmd == "show": show_config() - + elif subcmd == "edit": edit_config() - + elif subcmd == "set": - key = getattr(args, 'key', None) - value = getattr(args, 'value', None) + key = getattr(args, "key", None) + value = getattr(args, "value", None) if not key or not value: print("Usage: hermes config set KEY VALUE") print() @@ -1152,81 +1163,78 @@ def config_command(args): print(" hermes config set OPENROUTER_API_KEY sk-or-...") sys.exit(1) set_config_value(key, value) - + elif subcmd == "path": print(get_config_path()) - + elif subcmd == "env-path": print(get_env_path()) - + elif subcmd == "migrate": print() print(color("🔄 Checking configuration for updates...", Colors.CYAN, Colors.BOLD)) print() - + # Check what's missing missing_env = get_missing_env_vars(required_only=False) missing_config = get_missing_config_fields() current_ver, latest_ver = check_config_version() - + if not missing_env and not missing_config and current_ver >= latest_ver: print(color("✓ Configuration is up to date!", Colors.GREEN)) print() return - + # Show what needs to be updated if current_ver < latest_ver: print(f" Config version: {current_ver} → {latest_ver}") - + if missing_config: print(f"\n {len(missing_config)} new config option(s) will be added with defaults") - + required_missing = [v for v in missing_env if v.get("is_required")] - optional_missing = [ - v for v in missing_env - if not v.get("is_required") and not v.get("advanced") - ] - + optional_missing = [v for v in missing_env if not v.get("is_required") and not v.get("advanced")] + if required_missing: print(f"\n ⚠️ {len(required_missing)} required API key(s) missing:") for var in required_missing: print(f" • {var['name']}") - + if optional_missing: print(f"\n ℹ️ {len(optional_missing)} optional API key(s) not configured:") for var in optional_missing: tools = var.get("tools", []) tools_str = f" (enables: {', '.join(tools[:2])})" if tools else "" print(f" • {var['name']}{tools_str}") - + print() - + # Run migration results = migrate_config(interactive=True, quiet=False) - + print() if results["env_added"] or results["config_added"]: print(color("✓ Configuration updated!", Colors.GREEN)) - + if results["warnings"]: print() for warning in results["warnings"]: print(color(f" ⚠️ {warning}", Colors.YELLOW)) - + print() - + elif subcmd == "check": # Non-interactive check for what's missing print() print(color("📋 Configuration Status", Colors.CYAN, Colors.BOLD)) print() - + current_ver, latest_ver = check_config_version() if current_ver >= latest_ver: print(f" Config version: {current_ver} ✓") else: print(color(f" Config version: {current_ver} → {latest_ver} (update available)", Colors.YELLOW)) - + print() print(color(" Required:", Colors.BOLD)) for var_name in REQUIRED_ENV_VARS: @@ -1234,7 +1242,7 @@ def config_command(args): print(f" ✓ {var_name}") else: print(color(f" ✗ {var_name} (missing)", Colors.RED)) - + print() print(color(" Optional:", Colors.BOLD)) for var_name, info in OPTIONAL_ENV_VARS.items(): @@ -1244,15 +1252,15 @@ def config_command(args): tools = info.get("tools", []) tools_str = f" → {', '.join(tools[:2])}" if tools else "" print(color(f" ○ {var_name}{tools_str}", Colors.DIM)) - + missing_config = get_missing_config_fields() if missing_config: print() print(color(f" {len(missing_config)} new config option(s) available", Colors.YELLOW)) - print(f" Run 'hermes config migrate' to add them") - + print(" Run 'hermes config migrate' to add them") + print() - + else: print(f"Unknown config command: {subcmd}") print() diff --git a/hermes_cli/cron.py b/hermes_cli/cron.py index b76ef5bac8..3a6c3f7c02 100644 --- a/hermes_cli/cron.py +++ b/hermes_cli/cron.py @@ -20,46 +20,46 @@ from hermes_cli.colors import Colors, color def cron_list(show_all: bool = False): """List all scheduled jobs.""" from cron.jobs import list_jobs - + jobs = list_jobs(include_disabled=show_all) - + if not jobs: print(color("No scheduled jobs.", Colors.DIM)) print(color("Create one with the /cron add command in chat, or via Telegram.", Colors.DIM)) return - + print() print(color("┌─────────────────────────────────────────────────────────────────────────┐", Colors.CYAN)) print(color("│ Scheduled Jobs │", Colors.CYAN)) print(color("└─────────────────────────────────────────────────────────────────────────┘", Colors.CYAN)) print() - + for job in jobs: job_id = job.get("id", "?")[:8] name = job.get("name", "(unnamed)") schedule = job.get("schedule_display", job.get("schedule", {}).get("value", "?")) enabled = job.get("enabled", True) next_run = job.get("next_run_at", "?") - + repeat_info = job.get("repeat", {}) repeat_times = repeat_info.get("times") repeat_completed = repeat_info.get("completed", 0) - + if repeat_times: repeat_str = f"{repeat_completed}/{repeat_times}" else: repeat_str = "∞" - + deliver = job.get("deliver", ["local"]) if isinstance(deliver, str): deliver = [deliver] deliver_str = ", ".join(deliver) - + if not enabled: status = color("[disabled]", Colors.RED) else: status = color("[active]", Colors.GREEN) - + print(f" {color(job_id, Colors.YELLOW)} {status}") print(f" Name: {name}") print(f" Schedule: {schedule}") @@ -67,9 +67,10 @@ def cron_list(show_all: bool = False): print(f" Next run: {next_run}") print(f" Deliver: {deliver_str}") print() - + # Warn if gateway isn't running from hermes_cli.gateway import find_gateway_pids + if not find_gateway_pids(): print(color(" ⚠ Gateway is not running — jobs won't fire automatically.", Colors.YELLOW)) print(color(" Start it with: hermes gateway install", Colors.DIM)) @@ -79,6 +80,7 @@ def cron_list(show_all: bool = False): def cron_tick(): """Run due jobs once and exit.""" from cron.scheduler import tick + tick(verbose=True) @@ -86,9 +88,9 @@ def cron_status(): """Show cron execution status.""" from cron.jobs import list_jobs from hermes_cli.gateway import find_gateway_pids - + print() - + pids = find_gateway_pids() if pids: print(color("✓ Gateway is running — cron jobs will fire automatically", Colors.GREEN)) @@ -99,9 +101,9 @@ def cron_status(): print(" To enable automatic execution:") print(" hermes gateway install # Install as system service (recommended)") print(" hermes gateway # Or run in foreground") - + print() - + jobs = list_jobs(include_disabled=False) if jobs: next_runs = [j.get("next_run_at") for j in jobs if j.get("next_run_at")] @@ -110,24 +112,24 @@ def cron_status(): print(f" Next run: {min(next_runs)}") else: print(" No active jobs") - + print() def cron_command(args): """Handle cron subcommands.""" - subcmd = getattr(args, 'cron_command', None) - + subcmd = getattr(args, "cron_command", None) + if subcmd is None or subcmd == "list": - show_all = getattr(args, 'all', False) + show_all = getattr(args, "all", False) cron_list(show_all) - + elif subcmd == "tick": cron_tick() - + elif subcmd == "status": cron_status() - + else: print(f"Unknown cron command: {subcmd}") print("Usage: hermes cron [list|status|tick]") diff --git a/hermes_cli/doctor.py b/hermes_cli/doctor.py index de55bdff93..0db7f1a7b6 100644 --- a/hermes_cli/doctor.py +++ b/hermes_cli/doctor.py @@ -5,18 +5,18 @@ Diagnoses issues with Hermes Agent setup. """ import os -import sys -import subprocess import shutil -from pathlib import Path +import subprocess +import sys -from hermes_cli.config import get_project_root, get_hermes_home, get_env_path +from hermes_cli.config import get_env_path, get_hermes_home, get_project_root PROJECT_ROOT = get_project_root() HERMES_HOME = get_hermes_home() # Load environment variables from ~/.hermes/.env so API key checks work from dotenv import load_dotenv + _env_path = get_env_path() if _env_path.exists(): try: @@ -33,7 +33,6 @@ os.environ.setdefault("MSWEA_SILENT_STARTUP", "1") from hermes_cli.colors import Colors, color from hermes_constants import OPENROUTER_MODELS_URL - _PROVIDER_ENV_HINTS = ( "OPENROUTER_API_KEY", "OPENAI_API_KEY", @@ -56,35 +55,38 @@ def _has_provider_env_config(content: str) -> bool: def check_ok(text: str, detail: str = ""): print(f" {color('✓', Colors.GREEN)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else "")) + def check_warn(text: str, detail: str = ""): print(f" {color('⚠', Colors.YELLOW)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else "")) + def check_fail(text: str, detail: str = ""): print(f" {color('✗', Colors.RED)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else "")) + def check_info(text: str): print(f" {color('→', Colors.CYAN)} {text}") def run_doctor(args): """Run diagnostic checks.""" - should_fix = getattr(args, 'fix', False) - + should_fix = getattr(args, "fix", False) + issues = [] manual_issues = [] # issues that can't be auto-fixed fixed_count = 0 - + print() print(color("┌─────────────────────────────────────────────────────────┐", Colors.CYAN)) print(color("│ 🩺 Hermes Doctor │", Colors.CYAN)) print(color("└─────────────────────────────────────────────────────────┘", Colors.CYAN)) - + # ========================================================================= # Check: Python version # ========================================================================= print() print(color("◆ Python Environment", Colors.CYAN, Colors.BOLD)) - + py_version = sys.version_info if py_version >= (3, 11): check_ok(f"Python {py_version.major}.{py_version.minor}.{py_version.micro}") @@ -96,20 +98,20 @@ def run_doctor(args): else: check_fail(f"Python {py_version.major}.{py_version.minor}.{py_version.micro}", "(3.10+ required)") issues.append("Upgrade Python to 3.10+") - + # Check if in virtual environment in_venv = sys.prefix != sys.base_prefix if in_venv: check_ok("Virtual environment active") else: check_warn("Not in virtual environment", "(recommended)") - + # ========================================================================= # Check: Required packages # ========================================================================= print() print(color("◆ Required Packages", Colors.CYAN, Colors.BOLD)) - + required_packages = [ ("openai", "OpenAI SDK"), ("rich", "Rich (terminal UI)"), @@ -117,13 +119,13 @@ def run_doctor(args): ("yaml", "PyYAML"), ("httpx", "HTTPX"), ] - + optional_packages = [ ("croniter", "Croniter (cron expressions)"), ("telegram", "python-telegram-bot"), ("discord", "discord.py"), ] - + for module, name in required_packages: try: __import__(module) @@ -131,25 +133,25 @@ def run_doctor(args): except ImportError: check_fail(name, "(missing)") issues.append(f"Install {name}: uv pip install {module}") - + for module, name in optional_packages: try: __import__(module) check_ok(name, "(optional)") except ImportError: check_warn(name, "(optional, not installed)") - + # ========================================================================= # Check: Configuration files # ========================================================================= print() print(color("◆ Configuration Files", Colors.CYAN, Colors.BOLD)) - + # Check ~/.hermes/.env (primary location for user config) - env_path = HERMES_HOME / '.env' + env_path = HERMES_HOME / ".env" if env_path.exists(): check_ok("~/.hermes/.env file exists") - + # Check for common issues content = env_path.read_text() if _has_provider_env_config(content): @@ -159,7 +161,7 @@ def run_doctor(args): issues.append("Run 'hermes setup' to configure API keys") else: # Also check project root as fallback - fallback_env = PROJECT_ROOT / '.env' + fallback_env = PROJECT_ROOT / ".env" if fallback_env.exists(): check_ok(".env file exists (in project directory)") else: @@ -173,17 +175,17 @@ def run_doctor(args): else: check_info("Run 'hermes setup' to create one") issues.append("Run 'hermes setup' to create .env") - + # Check ~/.hermes/config.yaml (primary) or project cli-config.yaml (fallback) - config_path = HERMES_HOME / 'config.yaml' + config_path = HERMES_HOME / "config.yaml" if config_path.exists(): check_ok("~/.hermes/config.yaml exists") else: - fallback_config = PROJECT_ROOT / 'cli-config.yaml' + fallback_config = PROJECT_ROOT / "cli-config.yaml" if fallback_config.exists(): check_ok("cli-config.yaml exists (in project directory)") else: - example_config = PROJECT_ROOT / 'cli-config.yaml.example' + example_config = PROJECT_ROOT / "cli-config.yaml.example" if should_fix and example_config.exists(): config_path.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(str(example_config), str(config_path)) @@ -194,7 +196,7 @@ def run_doctor(args): manual_issues.append("Create ~/.hermes/config.yaml manually") else: check_warn("config.yaml not found", "(using defaults)") - + # ========================================================================= # Check: Auth providers # ========================================================================= @@ -202,7 +204,7 @@ def run_doctor(args): print(color("◆ Auth Providers", Colors.CYAN, Colors.BOLD)) try: - from hermes_cli.auth import get_nous_auth_status, get_codex_auth_status + from hermes_cli.auth import get_codex_auth_status, get_nous_auth_status nous_status = get_nous_auth_status() if nous_status.get("logged_in"): @@ -230,7 +232,7 @@ def run_doctor(args): # ========================================================================= print() print(color("◆ Directory Structure", Colors.CYAN, Colors.BOLD)) - + hermes_home = HERMES_HOME if hermes_home.exists(): check_ok("~/.hermes directory exists") @@ -241,7 +243,7 @@ def run_doctor(args): fixed_count += 1 else: check_warn("~/.hermes not found", "(will be created on first use)") - + # Check expected subdirectories expected_subdirs = ["cron", "sessions", "logs", "skills", "memories"] for subdir_name in expected_subdirs: @@ -255,7 +257,7 @@ def run_doctor(args): fixed_count += 1 else: check_warn(f"~/.hermes/{subdir_name}/ not found", "(will be created on first use)") - + # Check for SOUL.md persona file soul_path = hermes_home / "SOUL.md" if soul_path.exists(): @@ -278,7 +280,7 @@ def run_doctor(args): ) check_ok("Created ~/.hermes/SOUL.md with basic template") fixed_count += 1 - + # Check memory directory memories_dir = hermes_home / "memories" if memories_dir.exists(): @@ -301,12 +303,13 @@ def run_doctor(args): memories_dir.mkdir(parents=True, exist_ok=True) check_ok("Created ~/.hermes/memories/") fixed_count += 1 - + # Check SQLite session store state_db_path = hermes_home / "state.db" if state_db_path.exists(): try: import sqlite3 + conn = sqlite3.connect(str(state_db_path)) cursor = conn.execute("SELECT COUNT(*) FROM sessions") count = cursor.fetchone()[0] @@ -316,26 +319,26 @@ def run_doctor(args): check_warn(f"~/.hermes/state.db exists but has issues: {e}") else: check_info("~/.hermes/state.db not created yet (will be created on first session)") - + # ========================================================================= # Check: External tools # ========================================================================= print() print(color("◆ External Tools", Colors.CYAN, Colors.BOLD)) - + # Git if shutil.which("git"): check_ok("git") else: check_warn("git not found", "(optional)") - + # ripgrep (optional, for faster file search) if shutil.which("rg"): check_ok("ripgrep (rg)", "(faster file search)") else: check_warn("ripgrep (rg) not found", "(file search uses grep fallback)") check_info("Install for faster search: sudo apt install ripgrep") - + # Docker (optional) terminal_env = os.getenv("TERMINAL_ENV", "local") if terminal_env == "docker": @@ -355,7 +358,7 @@ def run_doctor(args): check_ok("docker", "(optional)") else: check_warn("docker not found", "(optional)") - + # SSH (if using ssh backend) if terminal_env == "ssh": ssh_host = os.getenv("TERMINAL_SSH_HOST") @@ -364,7 +367,7 @@ def run_doctor(args): result = subprocess.run( ["ssh", "-o", "ConnectTimeout=5", "-o", "BatchMode=yes", ssh_host, "echo ok"], capture_output=True, - text=True + text=True, ) if result.returncode == 0: check_ok(f"SSH connection to {ssh_host}") @@ -374,7 +377,7 @@ def run_doctor(args): else: check_fail("TERMINAL_SSH_HOST not set", "(required for TERMINAL_ENV=ssh)") issues.append("Set TERMINAL_SSH_HOST in .env") - + # Daytona (if using daytona backend) if terminal_env == "daytona": daytona_key = os.getenv("DAYTONA_API_KEY") @@ -385,6 +388,7 @@ def run_doctor(args): issues.append("Set DAYTONA_API_KEY environment variable") try: from daytona import Daytona + check_ok("daytona SDK", "(installed)") except ImportError: check_fail("daytona SDK not installed", "(pip install daytona)") @@ -401,7 +405,7 @@ def run_doctor(args): check_warn("agent-browser not installed", "(run: npm install)") else: check_warn("Node.js not found", "(optional, needed for browser tools)") - + # npm audit for all Node.js packages if shutil.which("npm"): npm_dirs = [ @@ -415,9 +419,12 @@ def run_doctor(args): audit_result = subprocess.run( ["npm", "audit", "--json"], cwd=str(npm_dir), - capture_output=True, text=True, timeout=30, + capture_output=True, + text=True, + timeout=30, ) import json as _json + audit_data = _json.loads(audit_result.stdout) if audit_result.stdout.strip() else {} vuln_count = audit_data.get("metadata", {}).get("vulnerabilities", {}) critical = vuln_count.get("critical", 0) @@ -429,7 +436,7 @@ def run_doctor(args): elif critical > 0 or high > 0: check_warn( f"{label} deps", - f"({critical} critical, {high} high, {moderate} moderate — run: cd {npm_dir} && npm audit fix)" + f"({critical} critical, {high} high, {moderate} moderate — run: cd {npm_dir} && npm audit fix)", ) issues.append(f"{label} has {total} npm vulnerability(ies)") else: @@ -442,47 +449,50 @@ def run_doctor(args): # ========================================================================= print() print(color("◆ API Connectivity", Colors.CYAN, Colors.BOLD)) - + openrouter_key = os.getenv("OPENROUTER_API_KEY") if openrouter_key: print(" Checking OpenRouter API...", end="", flush=True) try: import httpx + response = httpx.get( - OPENROUTER_MODELS_URL, - headers={"Authorization": f"Bearer {openrouter_key}"}, - timeout=10 + OPENROUTER_MODELS_URL, headers={"Authorization": f"Bearer {openrouter_key}"}, timeout=10 ) if response.status_code == 200: print(f"\r {color('✓', Colors.GREEN)} OpenRouter API ") elif response.status_code == 401: - print(f"\r {color('✗', Colors.RED)} OpenRouter API {color('(invalid API key)', Colors.DIM)} ") + print( + f"\r {color('✗', Colors.RED)} OpenRouter API {color('(invalid API key)', Colors.DIM)} " + ) issues.append("Check OPENROUTER_API_KEY in .env") else: - print(f"\r {color('✗', Colors.RED)} OpenRouter API {color(f'(HTTP {response.status_code})', Colors.DIM)} ") + print( + f"\r {color('✗', Colors.RED)} OpenRouter API {color(f'(HTTP {response.status_code})', Colors.DIM)} " + ) except Exception as e: print(f"\r {color('✗', Colors.RED)} OpenRouter API {color(f'({e})', Colors.DIM)} ") issues.append("Check network connectivity") else: check_warn("OpenRouter API", "(not configured)") - + anthropic_key = os.getenv("ANTHROPIC_API_KEY") if anthropic_key: print(" Checking Anthropic API...", end="", flush=True) try: import httpx + response = httpx.get( "https://api.anthropic.com/v1/models", - headers={ - "x-api-key": anthropic_key, - "anthropic-version": "2023-06-01" - }, - timeout=10 + headers={"x-api-key": anthropic_key, "anthropic-version": "2023-06-01"}, + timeout=10, ) if response.status_code == 200: print(f"\r {color('✓', Colors.GREEN)} Anthropic API ") elif response.status_code == 401: - print(f"\r {color('✗', Colors.RED)} Anthropic API {color('(invalid API key)', Colors.DIM)} ") + print( + f"\r {color('✗', Colors.RED)} Anthropic API {color('(invalid API key)', Colors.DIM)} " + ) else: msg = "(couldn't verify)" print(f"\r {color('⚠', Colors.YELLOW)} Anthropic API {color(msg, Colors.DIM)} ") @@ -491,10 +501,15 @@ def run_doctor(args): # -- API-key providers (Z.AI/GLM, Kimi, MiniMax, MiniMax-CN) -- _apikey_providers = [ - ("Z.AI / GLM", ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"), "https://api.z.ai/api/paas/v4/models", "GLM_BASE_URL"), - ("Kimi / Moonshot", ("KIMI_API_KEY",), "https://api.moonshot.ai/v1/models", "KIMI_BASE_URL"), - ("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL"), - ("MiniMax (China)", ("MINIMAX_CN_API_KEY",), "https://api.minimaxi.com/v1/models", "MINIMAX_CN_BASE_URL"), + ( + "Z.AI / GLM", + ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"), + "https://api.z.ai/api/paas/v4/models", + "GLM_BASE_URL", + ), + ("Kimi / Moonshot", ("KIMI_API_KEY",), "https://api.moonshot.ai/v1/models", "KIMI_BASE_URL"), + ("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL"), + ("MiniMax (China)", ("MINIMAX_CN_API_KEY",), "https://api.minimaxi.com/v1/models", "MINIMAX_CN_BASE_URL"), ] for _pname, _env_vars, _default_url, _base_env in _apikey_providers: _key = "" @@ -507,6 +522,7 @@ def run_doctor(args): print(f" Checking {_pname} API...", end="", flush=True) try: import httpx + _base = os.getenv(_base_env, "") # Auto-detect Kimi Code keys (sk-kimi-) → api.kimi.com if not _base and _key.startswith("sk-kimi-"): @@ -526,7 +542,9 @@ def run_doctor(args): print(f"\r {color('✗', Colors.RED)} {_label} {color('(invalid API key)', Colors.DIM)} ") issues.append(f"Check {_env_vars[0]} in .env") else: - print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'(HTTP {_resp.status_code})', Colors.DIM)} ") + print( + f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'(HTTP {_resp.status_code})', Colors.DIM)} " + ) except Exception as _e: print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'({_e})', Colors.DIM)} ") @@ -535,7 +553,7 @@ def run_doctor(args): # ========================================================================= print() print(color("◆ Submodules", Colors.CYAN, Colors.BOLD)) - + # mini-swe-agent (terminal tool backend) mini_swe_dir = PROJECT_ROOT / "mini-swe-agent" if mini_swe_dir.exists() and (mini_swe_dir / "pyproject.toml").exists(): @@ -547,7 +565,7 @@ def run_doctor(args): issues.append("Install mini-swe-agent: uv pip install -e ./mini-swe-agent") else: check_warn("mini-swe-agent not found", "(run: git submodule update --init --recursive)") - + # tinker-atropos (RL training backend) tinker_dir = PROJECT_ROOT / "tinker-atropos" if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists(): @@ -562,24 +580,24 @@ def run_doctor(args): check_warn("tinker-atropos requires Python 3.11+", f"(current: {py_version.major}.{py_version.minor})") else: check_warn("tinker-atropos not found", "(run: git submodule update --init --recursive)") - + # ========================================================================= # Check: Tool Availability # ========================================================================= print() print(color("◆ Tool Availability", Colors.CYAN, Colors.BOLD)) - + try: # Add project root to path for imports sys.path.insert(0, str(PROJECT_ROOT)) - from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS - + from model_tools import TOOLSET_REQUIREMENTS, check_tool_availability + available, unavailable = check_tool_availability() - + for tid in available: info = TOOLSET_REQUIREMENTS.get(tid, {}) check_ok(info.get("name", tid)) - + for item in unavailable: env_vars = item.get("missing_vars") or item.get("env_vars") or [] if env_vars: @@ -594,7 +612,7 @@ def run_doctor(args): issues.append("Run 'hermes setup' to configure missing API keys for full tool access") except Exception as e: check_warn("Could not check tool availability", f"({e})") - + # ========================================================================= # Check: Skills Hub # ========================================================================= @@ -608,6 +626,7 @@ def run_doctor(args): if lock_file.exists(): try: import json + lock_data = json.loads(lock_file.read_text()) count = len(lock_data.get("installed", {})) check_ok(f"Lock file OK ({count} hub-installed skill(s))") @@ -621,6 +640,7 @@ def run_doctor(args): check_warn("Skills Hub directory not initialized", "(run: hermes skills list)") from hermes_cli.config import get_env_value + github_token = get_env_value("GITHUB_TOKEN") or get_env_value("GH_TOKEN") if github_token: check_ok("GitHub token configured (authenticated API access)") @@ -656,5 +676,5 @@ def run_doctor(args): else: print(color("─" * 60, Colors.GREEN)) print(color(" All checks passed! 🎉", Colors.GREEN, Colors.BOLD)) - + print() diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index 3d146546da..4b45497ebd 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -13,18 +13,24 @@ from pathlib import Path PROJECT_ROOT = Path(__file__).parent.parent.resolve() +from hermes_cli.colors import Colors, color from hermes_cli.config import get_env_value, save_env_value from hermes_cli.setup import ( - print_header, print_info, print_success, print_warning, print_error, - prompt, prompt_choice, prompt_yes_no, + print_error, + print_header, + print_info, + print_success, + print_warning, + prompt, + prompt_choice, + prompt_yes_no, ) -from hermes_cli.colors import Colors, color - # ============================================================================= # Process Management (for manual gateway runs) # ============================================================================= + def find_gateway_pids() -> list: """Find PIDs of running gateway processes.""" pids = [] @@ -38,17 +44,16 @@ def find_gateway_pids() -> list: if is_windows(): # Windows: use wmic to search command lines result = subprocess.run( - ["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"], - capture_output=True, text=True + ["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"], capture_output=True, text=True ) # Parse WMIC LIST output: blocks of "CommandLine=...\nProcessId=...\n" current_cmd = "" - for line in result.stdout.split('\n'): + for line in result.stdout.split("\n"): line = line.strip() if line.startswith("CommandLine="): - current_cmd = line[len("CommandLine="):] + current_cmd = line[len("CommandLine=") :] elif line.startswith("ProcessId="): - pid_str = line[len("ProcessId="):] + pid_str = line[len("ProcessId=") :] if any(p in current_cmd for p in patterns): try: pid = int(pid_str) @@ -58,14 +63,10 @@ def find_gateway_pids() -> list: pass current_cmd = "" else: - result = subprocess.run( - ["ps", "aux"], - capture_output=True, - text=True - ) - for line in result.stdout.split('\n'): + result = subprocess.run(["ps", "aux"], capture_output=True, text=True) + for line in result.stdout.split("\n"): # Skip grep and current process - if 'grep' in line or str(os.getpid()) in line: + if "grep" in line or str(os.getpid()) in line: continue for pattern in patterns: if pattern in line: @@ -88,7 +89,7 @@ def kill_gateway_processes(force: bool = False) -> int: """Kill any running gateway processes. Returns count killed.""" pids = find_gateway_pids() killed = 0 - + for pid in pids: try: if force and not is_windows(): @@ -101,18 +102,20 @@ def kill_gateway_processes(force: bool = False) -> int: pass except PermissionError: print(f"⚠ Permission denied to kill PID {pid}") - + return killed def is_linux() -> bool: - return sys.platform.startswith('linux') + return sys.platform.startswith("linux") + def is_macos() -> bool: - return sys.platform == 'darwin' + return sys.platform == "darwin" + def is_windows() -> bool: - return sys.platform == 'win32' + return sys.platform == "win32" # ============================================================================= @@ -122,12 +125,15 @@ def is_windows() -> bool: SERVICE_NAME = "hermes-gateway" SERVICE_DESCRIPTION = "Hermes Agent Gateway - Messaging Platform Integration" + def get_systemd_unit_path() -> Path: return Path.home() / ".config" / "systemd" / "user" / f"{SERVICE_NAME}.service" + def get_launchd_plist_path() -> Path: return Path.home() / "Library" / "LaunchAgents" / "ai.hermes.gateway.plist" + def get_python_path() -> str: if is_windows(): venv_python = PROJECT_ROOT / "venv" / "Scripts" / "python.exe" @@ -137,14 +143,16 @@ def get_python_path() -> str: return str(venv_python) return sys.executable + def get_hermes_cli_path() -> str: """Get the path to the hermes CLI.""" # Check if installed via pip import shutil + hermes_bin = shutil.which("hermes") if hermes_bin: return hermes_bin - + # Fallback to direct module execution return f"{get_python_path()} -m hermes_cli.main" @@ -153,8 +161,10 @@ def get_hermes_cli_path() -> str: # Systemd (Linux) # ============================================================================= + def generate_systemd_unit() -> str: import shutil + python_path = get_python_path() working_dir = str(PROJECT_ROOT) venv_dir = str(PROJECT_ROOT / "venv") @@ -163,7 +173,7 @@ def generate_systemd_unit() -> str: # Build a PATH that includes the venv, node_modules, and standard system dirs sane_path = f"{venv_bin}:{node_bin}:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" - + hermes_cli = shutil.which("hermes") or f"{python_path} -m hermes_cli.main" return f"""[Unit] Description={SERVICE_DESCRIPTION} @@ -188,56 +198,62 @@ StandardError=journal WantedBy=default.target """ + def systemd_install(force: bool = False): unit_path = get_systemd_unit_path() - + if unit_path.exists() and not force: print(f"Service already installed at: {unit_path}") print("Use --force to reinstall") return - + unit_path.parent.mkdir(parents=True, exist_ok=True) print(f"Installing systemd service to: {unit_path}") unit_path.write_text(generate_systemd_unit()) - + subprocess.run(["systemctl", "--user", "daemon-reload"], check=True) subprocess.run(["systemctl", "--user", "enable", SERVICE_NAME], check=True) - + print() print("✓ Service installed and enabled!") print() print("Next steps:") - print(f" hermes gateway start # Start the service") - print(f" hermes gateway status # Check status") + print(" hermes gateway start # Start the service") + print(" hermes gateway status # Check status") print(f" journalctl --user -u {SERVICE_NAME} -f # View logs") print() print("To enable lingering (keeps running after logout):") print(" sudo loginctl enable-linger $USER") + def systemd_uninstall(): subprocess.run(["systemctl", "--user", "stop", SERVICE_NAME], check=False) subprocess.run(["systemctl", "--user", "disable", SERVICE_NAME], check=False) - + unit_path = get_systemd_unit_path() if unit_path.exists(): unit_path.unlink() print(f"✓ Removed {unit_path}") - + subprocess.run(["systemctl", "--user", "daemon-reload"], check=True) print("✓ Service uninstalled") + def systemd_start(): subprocess.run(["systemctl", "--user", "start", SERVICE_NAME], check=True) print("✓ Service started") + def systemd_stop(): subprocess.run(["systemctl", "--user", "stop", SERVICE_NAME], check=True) print("✓ Service stopped") + def systemd_restart(): subprocess.run(["systemctl", "--user", "restart", SERVICE_NAME], check=True) print("✓ Service restarted") + def systemd_status(deep: bool = False): # Check if service unit file exists unit_path = get_systemd_unit_path() @@ -245,54 +261,45 @@ def systemd_status(deep: bool = False): print("✗ Gateway service is not installed") print(" Run: hermes gateway install") return - + # Show detailed status first - subprocess.run( - ["systemctl", "--user", "status", SERVICE_NAME, "--no-pager"], - capture_output=False - ) - + subprocess.run(["systemctl", "--user", "status", SERVICE_NAME, "--no-pager"], capture_output=False) + # Check if service is active - result = subprocess.run( - ["systemctl", "--user", "is-active", SERVICE_NAME], - capture_output=True, - text=True - ) - + result = subprocess.run(["systemctl", "--user", "is-active", SERVICE_NAME], capture_output=True, text=True) + status = result.stdout.strip() - + if status == "active": print("✓ Gateway service is running") else: print("✗ Gateway service is stopped") print(" Run: hermes gateway start") - + if deep: print() print("Recent logs:") - subprocess.run([ - "journalctl", "--user", "-u", SERVICE_NAME, - "-n", "20", "--no-pager" - ]) + subprocess.run(["journalctl", "--user", "-u", SERVICE_NAME, "-n", "20", "--no-pager"]) # ============================================================================= # Launchd (macOS) # ============================================================================= + def generate_launchd_plist() -> str: python_path = get_python_path() working_dir = str(PROJECT_ROOT) log_dir = Path.home() / ".hermes" / "logs" log_dir.mkdir(parents=True, exist_ok=True) - + return f""" Label ai.hermes.gateway - + ProgramArguments {python_path} @@ -301,42 +308,43 @@ def generate_launchd_plist() -> str: gateway run - + WorkingDirectory {working_dir} - + RunAtLoad - + KeepAlive SuccessfulExit - + StandardOutPath {log_dir}/gateway.log - + StandardErrorPath {log_dir}/gateway.error.log """ + def launchd_install(force: bool = False): plist_path = get_launchd_plist_path() - + if plist_path.exists() and not force: print(f"Service already installed at: {plist_path}") print("Use --force to reinstall") return - + plist_path.parent.mkdir(parents=True, exist_ok=True) print(f"Installing launchd service to: {plist_path}") plist_path.write_text(generate_launchd_plist()) - + subprocess.run(["launchctl", "load", str(plist_path)], check=True) - + print() print("✓ Service installed and loaded!") print() @@ -344,41 +352,42 @@ def launchd_install(force: bool = False): print(" hermes gateway status # Check status") print(" tail -f ~/.hermes/logs/gateway.log # View logs") + def launchd_uninstall(): plist_path = get_launchd_plist_path() subprocess.run(["launchctl", "unload", str(plist_path)], check=False) - + if plist_path.exists(): plist_path.unlink() print(f"✓ Removed {plist_path}") - + print("✓ Service uninstalled") + def launchd_start(): subprocess.run(["launchctl", "start", "ai.hermes.gateway"], check=True) print("✓ Service started") + def launchd_stop(): subprocess.run(["launchctl", "stop", "ai.hermes.gateway"], check=True) print("✓ Service stopped") + def launchd_restart(): launchd_stop() launchd_start() + def launchd_status(deep: bool = False): - result = subprocess.run( - ["launchctl", "list", "ai.hermes.gateway"], - capture_output=True, - text=True - ) - + result = subprocess.run(["launchctl", "list", "ai.hermes.gateway"], capture_output=True, text=True) + if result.returncode == 0: print("✓ Gateway service is loaded") print(result.stdout) else: print("✗ Gateway service is not loaded") - + if deep: log_file = Path.home() / ".hermes" / "logs" / "gateway.log" if log_file.exists(): @@ -391,9 +400,10 @@ def launchd_status(deep: bool = False): # Gateway Runner # ============================================================================= + def run_gateway(verbose: bool = False, replace: bool = False): """Run the gateway in foreground. - + Args: verbose: Enable verbose logging output. replace: If True, kill any existing gateway instance before starting. @@ -401,9 +411,9 @@ def run_gateway(verbose: bool = False, replace: bool = False): hasn't fully exited yet. """ sys.path.insert(0, str(PROJECT_ROOT)) - + from gateway.run import start_gateway - + print("┌─────────────────────────────────────────────────────────┐") print("│ ⚕ Hermes Gateway Starting... │") print("├─────────────────────────────────────────────────────────┤") @@ -411,7 +421,7 @@ def run_gateway(verbose: bool = False, replace: bool = False): print("│ Press Ctrl+C to stop │") print("└─────────────────────────────────────────────────────────┘") print() - + # Exit with code 1 if gateway fails to connect any platform, # so systemd Restart=on-failure will retry on transient errors success = asyncio.run(start_gateway(replace=replace)) @@ -438,13 +448,25 @@ _PLATFORMS = [ "4. To find your user ID: message @userinfobot — it replies with your numeric ID", ], "vars": [ - {"name": "TELEGRAM_BOT_TOKEN", "prompt": "Bot token", "password": True, - "help": "Paste the token from @BotFather (step 3 above)."}, - {"name": "TELEGRAM_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated)", "password": False, - "is_allowlist": True, - "help": "Paste your user ID from step 4 above."}, - {"name": "TELEGRAM_HOME_CHANNEL", "prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)", "password": False, - "help": "For DMs, this is your user ID. You can set it later by typing /set-home in chat."}, + { + "name": "TELEGRAM_BOT_TOKEN", + "prompt": "Bot token", + "password": True, + "help": "Paste the token from @BotFather (step 3 above).", + }, + { + "name": "TELEGRAM_ALLOWED_USERS", + "prompt": "Allowed user IDs (comma-separated)", + "password": False, + "is_allowlist": True, + "help": "Paste your user ID from step 4 above.", + }, + { + "name": "TELEGRAM_HOME_CHANNEL", + "prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)", + "password": False, + "help": "For DMs, this is your user ID. You can set it later by typing /set-home in chat.", + }, ], }, { @@ -466,13 +488,25 @@ _PLATFORMS = [ " then right-click your name → Copy ID", ], "vars": [ - {"name": "DISCORD_BOT_TOKEN", "prompt": "Bot token", "password": True, - "help": "Paste the token from step 2 above."}, - {"name": "DISCORD_ALLOWED_USERS", "prompt": "Allowed user IDs or usernames (comma-separated)", "password": False, - "is_allowlist": True, - "help": "Paste your user ID from step 5 above."}, - {"name": "DISCORD_HOME_CHANNEL", "prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)", "password": False, - "help": "Right-click a channel → Copy Channel ID (requires Developer Mode)."}, + { + "name": "DISCORD_BOT_TOKEN", + "prompt": "Bot token", + "password": True, + "help": "Paste the token from step 2 above.", + }, + { + "name": "DISCORD_ALLOWED_USERS", + "prompt": "Allowed user IDs or usernames (comma-separated)", + "password": False, + "is_allowlist": True, + "help": "Paste your user ID from step 5 above.", + }, + { + "name": "DISCORD_HOME_CHANNEL", + "prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)", + "password": False, + "help": "Right-click a channel → Copy Channel ID (requires Developer Mode).", + }, ], }, { @@ -497,13 +531,25 @@ _PLATFORMS = [ "8. Invite the bot to channels: /invite @YourBot", ], "vars": [ - {"name": "SLACK_BOT_TOKEN", "prompt": "Bot Token (xoxb-...)", "password": True, - "help": "Paste the bot token from step 3 above."}, - {"name": "SLACK_APP_TOKEN", "prompt": "App Token (xapp-...)", "password": True, - "help": "Paste the app-level token from step 4 above."}, - {"name": "SLACK_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated)", "password": False, - "is_allowlist": True, - "help": "Paste your member ID from step 7 above."}, + { + "name": "SLACK_BOT_TOKEN", + "prompt": "Bot Token (xoxb-...)", + "password": True, + "help": "Paste the bot token from step 3 above.", + }, + { + "name": "SLACK_APP_TOKEN", + "prompt": "App Token (xapp-...)", + "password": True, + "help": "Paste the app-level token from step 4 above.", + }, + { + "name": "SLACK_ALLOWED_USERS", + "prompt": "Allowed user IDs (comma-separated)", + "password": False, + "is_allowlist": True, + "help": "Paste your member ID from step 7 above.", + }, ], }, { @@ -582,14 +628,14 @@ def _setup_standard_platform(platform: dict): # Allowlist fields get special handling for the deny-by-default security model if var.get("is_allowlist"): - print_info(f" The gateway DENIES all users by default for security.") - print_info(f" Enter user IDs to create an allowlist, or leave empty") - print_info(f" and you'll be asked about open access next.") + print_info(" The gateway DENIES all users by default for security.") + print_info(" Enter user IDs to create an allowlist, or leave empty") + print_info(" and you'll be asked about open access next.") value = prompt(f" {var['prompt']}", password=False) if value: cleaned = value.replace(" ", "") save_env_value(var["name"], cleaned) - print_success(f" Saved — only these users can interact with the bot.") + print_success(" Saved — only these users can interact with the bot.") allowed_val_set = cleaned else: # No allowlist — ask about open access vs DM pairing @@ -618,7 +664,7 @@ def _setup_standard_platform(platform: dict): print_warning(f" Skipped — {label} won't work without this.") return else: - print_info(f" Skipped (can configure later)") + print_info(" Skipped (can configure later)") # If an allowlist was set and home channel wasn't, offer to reuse # the first user ID (common for Telegram DMs). @@ -636,8 +682,10 @@ def _setup_standard_platform(platform: dict): def _setup_whatsapp(): """Delegate to the existing WhatsApp setup flow.""" - from hermes_cli.main import cmd_whatsapp import argparse + + from hermes_cli.main import cmd_whatsapp + cmd_whatsapp(argparse.Namespace()) @@ -653,16 +701,10 @@ def _is_service_installed() -> bool: def _is_service_running() -> bool: """Check if the gateway service is currently running.""" if is_linux() and get_systemd_unit_path().exists(): - result = subprocess.run( - ["systemctl", "--user", "is-active", SERVICE_NAME], - capture_output=True, text=True - ) + result = subprocess.run(["systemctl", "--user", "is-active", SERVICE_NAME], capture_output=True, text=True) return result.stdout.strip() == "active" elif is_macos() and get_launchd_plist_path().exists(): - result = subprocess.run( - ["launchctl", "list", "ai.hermes.gateway"], - capture_output=True, text=True - ) + result = subprocess.run(["launchctl", "list", "ai.hermes.gateway"], capture_output=True, text=True) return result.returncode == 0 # Check for manual processes return len(find_gateway_pids()) > 0 @@ -697,7 +739,7 @@ def _setup_signal(): print_info(" Docker: bbernhard/signal-cli-rest-api") print() print_info(" After installing, link your account and start the daemon:") - print_info(" signal-cli link -n \"HermesAgent\"") + print_info(' signal-cli link -n "HermesAgent"') print_info(" signal-cli --account +YOURNUMBER daemon --http 127.0.0.1:8080") print() @@ -715,6 +757,7 @@ def _setup_signal(): print_info(" Testing connection...") try: import httpx + resp = httpx.get(f"{url.rstrip('/')}/api/v1/check", timeout=10.0) if resp.status_code == 200: print_success(" signal-cli daemon is reachable!") @@ -779,7 +822,7 @@ def _setup_signal(): print_success("Signal configured!") print_info(f" URL: {url}") print_info(f" Account: {account}") - print_info(f" DM auth: via SIGNAL_ALLOWED_USERS + DM pairing") + print_info(" DM auth: via SIGNAL_ALLOWED_USERS + DM pairing") print_info(f" Groups: {'enabled' if get_env_value('SIGNAL_GROUP_ALLOWED_USERS') else 'disabled'}") @@ -841,11 +884,10 @@ def gateway_setup(): _setup_standard_platform(platform) # ── Post-setup: offer to install/restart gateway ── - any_configured = any( - bool(get_env_value(p["token_var"])) - for p in _PLATFORMS - if p["key"] != "whatsapp" - ) or (get_env_value("WHATSAPP_ENABLED") or "").lower() == "true" + any_configured = ( + any(bool(get_env_value(p["token_var"])) for p in _PLATFORMS if p["key"] != "whatsapp") + or (get_env_value("WHATSAPP_ENABLED") or "").lower() == "true" + ) if any_configured: print() @@ -878,7 +920,9 @@ def gateway_setup(): print() if is_linux() or is_macos(): platform_name = "systemd" if is_linux() else "launchd" - if prompt_yes_no(f" Install the gateway as a {platform_name} service? (runs in background, starts on boot)", True): + if prompt_yes_no( + f" Install the gateway as a {platform_name} service? (runs in background, starts on boot)", True + ): try: force = False if is_linux(): @@ -914,14 +958,15 @@ def gateway_setup(): # Main Command Handler # ============================================================================= + def gateway_command(args): """Handle gateway subcommands.""" - subcmd = getattr(args, 'gateway_command', None) - + subcmd = getattr(args, "gateway_command", None) + # Default to run if no subcommand if subcmd is None or subcmd == "run": - verbose = getattr(args, 'verbose', False) - replace = getattr(args, 'replace', False) + verbose = getattr(args, "verbose", False) + replace = getattr(args, "replace", False) run_gateway(verbose, replace=replace) return @@ -931,7 +976,7 @@ def gateway_command(args): # Service management commands if subcmd == "install": - force = getattr(args, 'force', False) + force = getattr(args, "force", False) if is_linux(): systemd_install(force) elif is_macos(): @@ -940,7 +985,7 @@ def gateway_command(args): print("Service installation not supported on this platform.") print("Run manually: hermes gateway run") sys.exit(1) - + elif subcmd == "uninstall": if is_linux(): systemd_uninstall() @@ -949,7 +994,7 @@ def gateway_command(args): else: print("Not supported on this platform.") sys.exit(1) - + elif subcmd == "start": if is_linux(): systemd_start() @@ -958,11 +1003,11 @@ def gateway_command(args): else: print("Not supported on this platform.") sys.exit(1) - + elif subcmd == "stop": # Try service first, fall back to killing processes directly service_available = False - + if is_linux() and get_systemd_unit_path().exists(): try: systemd_stop() @@ -975,7 +1020,7 @@ def gateway_command(args): service_available = True except subprocess.CalledProcessError: pass - + if not service_available: # Kill gateway processes directly killed = kill_gateway_processes() @@ -983,11 +1028,11 @@ def gateway_command(args): print(f"✓ Stopped {killed} gateway process(es)") else: print("✗ No gateway processes found") - + elif subcmd == "restart": # Try service first, fall back to killing and restarting service_available = False - + if is_linux() and get_systemd_unit_path().exists(): try: systemd_restart() @@ -1000,23 +1045,24 @@ def gateway_command(args): service_available = True except subprocess.CalledProcessError: pass - + if not service_available: # Manual restart: kill existing processes killed = kill_gateway_processes() if killed: print(f"✓ Stopped {killed} gateway process(es)") - + import time + time.sleep(2) - + # Start fresh print("Starting gateway...") run_gateway(verbose=False) - + elif subcmd == "status": - deep = getattr(args, 'deep', False) - + deep = getattr(args, "deep", False) + # Check for service first if is_linux() and get_systemd_unit_path().exists(): systemd_status(deep) diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 861cc038bf..ca89344526 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -28,7 +28,6 @@ import argparse import os import sys from pathlib import Path -from typing import Optional # Add project root to path PROJECT_ROOT = Path(__file__).parent.parent.resolve() @@ -36,14 +35,16 @@ sys.path.insert(0, str(PROJECT_ROOT)) # Load .env from ~/.hermes/.env first, then project root as dev fallback from dotenv import load_dotenv + from hermes_cli.config import get_env_path, get_hermes_home + _user_env = get_env_path() if _user_env.exists(): try: load_dotenv(dotenv_path=_user_env, encoding="utf-8") except UnicodeDecodeError: load_dotenv(dotenv_path=_user_env, encoding="latin-1") -load_dotenv(dotenv_path=PROJECT_ROOT / '.env', override=False) +load_dotenv(dotenv_path=PROJECT_ROOT / ".env", override=False) # Point mini-swe-agent at ~/.hermes/ so it shares our config os.environ.setdefault("MSWEA_GLOBAL_CONFIG_DIR", str(get_hermes_home())) @@ -59,13 +60,11 @@ logger = logging.getLogger(__name__) def _has_any_provider_configured() -> bool: """Check if at least one inference provider is usable.""" - from hermes_cli.config import get_env_path, get_hermes_home - from hermes_cli.auth import get_auth_status - # Check env vars (may be set by .env or shell). # OPENAI_BASE_URL alone counts — local models (vLLM, llama.cpp, etc.) # often don't require an API key. - from hermes_cli.auth import PROVIDER_REGISTRY + from hermes_cli.auth import PROVIDER_REGISTRY, get_auth_status + from hermes_cli.config import get_env_path, get_hermes_home # Collect all provider env vars provider_env_vars = {"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "OPENAI_BASE_URL"} @@ -95,6 +94,7 @@ def _has_any_provider_configured() -> bool: if auth_file.exists(): try: import json + auth = json.loads(auth_file.read_text()) active = auth.get("active_provider") if active: @@ -107,7 +107,7 @@ def _has_any_provider_configured() -> bool: return False -def _session_browse_picker(sessions: list) -> Optional[str]: +def _session_browse_picker(sessions: list) -> str | None: """Interactive curses-based session browser with live search filtering. Returns the selected session ID, or None if cancelled. @@ -180,10 +180,10 @@ def _session_browse_picker(sessions: list) -> Optional[str]: if curses.has_colors(): curses.start_color() curses.use_default_colors() - curses.init_pair(1, curses.COLOR_GREEN, -1) # selected + curses.init_pair(1, curses.COLOR_GREEN, -1) # selected curses.init_pair(2, curses.COLOR_YELLOW, -1) # header - curses.init_pair(3, curses.COLOR_CYAN, -1) # search - curses.init_pair(4, 8, -1) # dim + curses.init_pair(3, curses.COLOR_CYAN, -1) # search + curses.init_pair(4, 8, -1) # dim cursor = 0 scroll_offset = 0 @@ -251,10 +251,7 @@ def _session_browse_picker(sessions: list) -> Optional[str]: elif cursor >= scroll_offset + visible_rows: scroll_offset = cursor - visible_rows + 1 - for draw_i, i in enumerate(range( - scroll_offset, - min(len(filtered), scroll_offset + visible_rows) - )): + for draw_i, i in enumerate(range(scroll_offset, min(len(filtered), scroll_offset + visible_rows))): y = draw_i + 3 if y >= max_y - 1: break @@ -280,18 +277,19 @@ def _session_browse_picker(sessions: list) -> Optional[str]: else: footer = f" 0/{len(sessions)} sessions" try: - stdscr.addnstr(footer_y, 0, footer, max_x - 1, - curses.color_pair(4) if curses.has_colors() else curses.A_DIM) + stdscr.addnstr( + footer_y, 0, footer, max_x - 1, curses.color_pair(4) if curses.has_colors() else curses.A_DIM + ) except curses.error: pass stdscr.refresh() key = stdscr.getch() - if key in (curses.KEY_UP, ): + if key in (curses.KEY_UP,): if filtered: cursor = (cursor - 1) % len(filtered) - elif key in (curses.KEY_DOWN, ): + elif key in (curses.KEY_DOWN,): if filtered: cursor = (cursor + 1) % len(filtered) elif key in (curses.KEY_ENTER, 10, 13): @@ -317,7 +315,7 @@ def _session_browse_picker(sessions: list) -> Optional[str]: filtered = list(sessions) cursor = 0 scroll_offset = 0 - elif key == ord('q') and not search_text: + elif key == ord("q") and not search_text: return elif 32 <= key <= 126: # Printable character → add to search filter @@ -374,16 +372,17 @@ def _session_browse_picker(sessions: list) -> Optional[str]: return sessions[idx]["id"] print(f" Invalid selection. Enter 1-{len(sessions)} or q to cancel.") except ValueError: - print(f" Invalid input. Enter a number or q to cancel.") + print(" Invalid input. Enter a number or q to cancel.") except (KeyboardInterrupt, EOFError): print() return None -def _resolve_last_cli_session() -> Optional[str]: +def _resolve_last_cli_session() -> str | None: """Look up the most recent CLI session ID from SQLite. Returns None if unavailable.""" try: from hermes_state import SessionDB + db = SessionDB() sessions = db.search_sessions(source="cli", limit=1) db.close() @@ -394,7 +393,7 @@ def _resolve_last_cli_session() -> Optional[str]: return None -def _resolve_session_by_name_or_id(name_or_id: str) -> Optional[str]: +def _resolve_session_by_name_or_id(name_or_id: str) -> str | None: """Resolve a session name (title) or ID to a session ID. - If it looks like a session ID (contains underscore + hex), try direct lookup first. @@ -403,6 +402,7 @@ def _resolve_session_by_name_or_id(name_or_id: str) -> Optional[str]: """ try: from hermes_state import SessionDB + db = SessionDB() # Try as exact session ID first @@ -473,13 +473,14 @@ def cmd_chat(args): # Sync bundled skills on every CLI launch (fast -- skips unchanged skills) try: from tools.skills_sync import sync_skills + sync_skills(quiet=True) except Exception: pass # Import and run the CLI from cli import main as cli_main - + # Build kwargs from args kwargs = { "model": args.model, @@ -492,21 +493,22 @@ def cmd_chat(args): } # Filter out None values kwargs = {k: v for k, v in kwargs.items() if v is not None} - + cli_main(**kwargs) def cmd_gateway(args): """Gateway management commands.""" from hermes_cli.gateway import gateway_command + gateway_command(args) def cmd_whatsapp(args): """Set up WhatsApp: choose mode, configure, install bridge, pair via QR.""" - import os import subprocess from pathlib import Path + from hermes_cli.config import get_env_value, save_env_value print() @@ -634,6 +636,7 @@ def cmd_whatsapp(args): response = "n" if response.lower() in ("y", "yes"): import shutil + shutil.rmtree(session_dir, ignore_errors=True) session_dir.mkdir(parents=True, exist_ok=True) print(" ✓ Session cleared") @@ -692,18 +695,18 @@ def cmd_whatsapp(args): def cmd_setup(args): """Interactive setup wizard.""" from hermes_cli.setup import run_setup_wizard + run_setup_wizard(args) def cmd_model(args): """Select default model — starts with provider selection, then model picker.""" from hermes_cli.auth import ( - resolve_provider, get_provider_auth_state, PROVIDER_REGISTRY, - _prompt_model_selection, _save_model_choice, _update_config_for_provider, - resolve_nous_runtime_credentials, fetch_nous_models, AuthError, format_auth_error, - _login_nous, + AuthError, + format_auth_error, + resolve_provider, ) - from hermes_cli.config import load_config, save_config, get_env_value, save_env_value + from hermes_cli.config import get_env_value, load_config config = load_config() current_model = config.get("model") @@ -714,16 +717,13 @@ def cmd_model(args): # Read effective provider the same way the CLI does at startup: # config.yaml model.provider > env var > auto-detect import os + config_provider = None model_cfg = config.get("model") if isinstance(model_cfg, dict): config_provider = model_cfg.get("provider") - effective_provider = ( - os.getenv("HERMES_INFERENCE_PROVIDER") - or config_provider - or "auto" - ) + effective_provider = os.getenv("HERMES_INFERENCE_PROVIDER") or config_provider or "auto" try: active = resolve_provider(effective_provider) except AuthError as exc: @@ -833,12 +833,16 @@ def _prompt_provider_choice(choices): """Show provider selection menu. Returns index or None.""" try: from simple_term_menu import TerminalMenu + menu_items = [f" {c}" for c in choices] menu = TerminalMenu( - menu_items, cursor_index=0, - menu_cursor="-> ", menu_cursor_style=("fg_green", "bold"), + menu_items, + cursor_index=0, + menu_cursor="-> ", + menu_cursor_style=("fg_green", "bold"), menu_highlight_style=("fg_green",), - cycle_cursor=True, clear_screen=False, + cycle_cursor=True, + clear_screen=False, title="Select provider:", ) idx = menu.show() @@ -891,6 +895,7 @@ def _model_flow_openrouter(config, current_model=""): print() from hermes_cli.models import model_ids + openrouter_models = model_ids() selected = _prompt_model_selection(openrouter_models, current_model=current_model) @@ -903,6 +908,7 @@ def _model_flow_openrouter(config, current_model=""): # Update config provider and deactivate any OAuth provider from hermes_cli.config import load_config, save_config + cfg = load_config() model = cfg.get("model") if isinstance(model, dict): @@ -917,14 +923,21 @@ def _model_flow_openrouter(config, current_model=""): def _model_flow_nous(config, current_model=""): """Nous Portal provider: ensure logged in, then pick model.""" + import argparse + from hermes_cli.auth import ( - get_provider_auth_state, _prompt_model_selection, _save_model_choice, - _update_config_for_provider, resolve_nous_runtime_credentials, - fetch_nous_models, AuthError, format_auth_error, - _login_nous, PROVIDER_REGISTRY, + PROVIDER_REGISTRY, + AuthError, + _login_nous, + _prompt_model_selection, + _save_model_choice, + _update_config_for_provider, + fetch_nous_models, + format_auth_error, + get_provider_auth_state, + resolve_nous_runtime_credentials, ) from hermes_cli.config import get_env_value, save_env_value - import argparse state = get_provider_auth_state("nous") if not state or not state.get("access_token"): @@ -932,9 +945,14 @@ def _model_flow_nous(config, current_model=""): print() try: mock_args = argparse.Namespace( - portal_url=None, inference_url=None, client_id=None, - scope=None, no_browser=False, timeout=15.0, - ca_bundle=None, insecure=False, + portal_url=None, + inference_url=None, + client_id=None, + scope=None, + no_browser=False, + timeout=15.0, + ca_bundle=None, + insecure=False, ) _login_nous(mock_args, PROVIDER_REGISTRY["nous"]) except SystemExit: @@ -962,9 +980,14 @@ def _model_flow_nous(config, current_model=""): print("Re-authenticating with Nous Portal...\n") try: mock_args = argparse.Namespace( - portal_url=None, inference_url=None, client_id=None, - scope=None, no_browser=False, timeout=15.0, - ca_bundle=None, insecure=False, + portal_url=None, + inference_url=None, + client_id=None, + scope=None, + no_browser=False, + timeout=15.0, + ca_bundle=None, + insecure=False, ) _login_nous(mock_args, PROVIDER_REGISTRY["nous"]) except Exception as login_exc: @@ -994,14 +1017,19 @@ def _model_flow_nous(config, current_model=""): def _model_flow_openai_codex(config, current_model=""): """OpenAI Codex provider: ensure logged in, then pick model.""" + import argparse + from hermes_cli.auth import ( - get_codex_auth_status, _prompt_model_selection, _save_model_choice, - _update_config_for_provider, _login_openai_codex, - PROVIDER_REGISTRY, DEFAULT_CODEX_BASE_URL, + DEFAULT_CODEX_BASE_URL, + PROVIDER_REGISTRY, + _login_openai_codex, + _prompt_model_selection, + _save_model_choice, + _update_config_for_provider, + get_codex_auth_status, ) from hermes_cli.codex_models import get_codex_model_ids from hermes_cli.config import get_env_value, save_env_value - import argparse status = get_codex_auth_status() if not status.get("logged_in"): @@ -1020,6 +1048,7 @@ def _model_flow_openai_codex(config, current_model=""): _codex_token = None try: from hermes_cli.auth import resolve_codex_runtime_credentials + _codex_creds = resolve_codex_runtime_credentials() _codex_token = _codex_creds.get("api_key") except Exception: @@ -1046,7 +1075,7 @@ def _model_flow_custom(config): so it appears in the provider menu on subsequent runs. """ from hermes_cli.auth import _save_model_choice, deactivate_provider - from hermes_cli.config import get_env_value, save_env_value, load_config, save_config + from hermes_cli.config import get_env_value, load_config, save_config, save_env_value current_url = get_env_value("OPENAI_BASE_URL") or "" current_key = get_env_value("OPENAI_API_KEY") or "" @@ -1130,6 +1159,7 @@ def _save_custom_provider(base_url, api_key="", model=""): # Auto-generate a name from the URL import re + clean = base_url.replace("https://", "").replace("http://", "").rstrip("/") # Remove /v1 suffix for cleaner names clean = re.sub(r"/v1/?$", "", clean) @@ -1152,7 +1182,7 @@ def _save_custom_provider(base_url, api_key="", model=""): providers.append(entry) cfg["custom_providers"] = providers save_config(cfg) - print(f" 💾 Saved to custom providers as \"{name}\" (edit in config.yaml)") + print(f' 💾 Saved to custom providers as "{name}" (edit in config.yaml)') def _remove_custom_provider(config): @@ -1180,11 +1210,15 @@ def _remove_custom_provider(config): try: from simple_term_menu import TerminalMenu + menu = TerminalMenu( - [f" {c}" for c in choices], cursor_index=0, - menu_cursor="-> ", menu_cursor_style=("fg_red", "bold"), + [f" {c}" for c in choices], + cursor_index=0, + menu_cursor="-> ", + menu_cursor_style=("fg_red", "bold"), menu_highlight_style=("fg_red",), - cycle_cursor=True, clear_screen=False, + cycle_cursor=True, + clear_screen=False, title="Select provider to remove:", ) idx = menu.show() @@ -1207,7 +1241,7 @@ def _remove_custom_provider(config): cfg["custom_providers"] = providers save_config(cfg) removed_name = removed.get("name", "unnamed") if isinstance(removed, dict) else str(removed) - print(f"✅ Removed \"{removed_name}\" from custom providers.") + print(f'✅ Removed "{removed_name}" from custom providers.') def _model_flow_named_custom(config, provider_info): @@ -1217,7 +1251,7 @@ def _model_flow_named_custom(config, provider_info): Otherwise probes the endpoint's /models API to let the user pick one. """ from hermes_cli.auth import _save_model_choice, deactivate_provider - from hermes_cli.config import save_env_value, load_config, save_config + from hermes_cli.config import load_config, save_config, save_env_value from hermes_cli.models import fetch_api_models name = provider_info["name"] @@ -1255,12 +1289,16 @@ def _model_flow_named_custom(config, provider_info): print(f"Found {len(models)} model(s):\n") try: from simple_term_menu import TerminalMenu + menu_items = [f" {m}" for m in models] + [" Cancel"] menu = TerminalMenu( - menu_items, cursor_index=0, - menu_cursor="-> ", menu_cursor_style=("fg_green", "bold"), + menu_items, + cursor_index=0, + menu_cursor="-> ", + menu_cursor_style=("fg_green", "bold"), menu_highlight_style=("fg_green",), - cycle_cursor=True, clear_screen=False, + cycle_cursor=True, + clear_screen=False, title=f"Select model from {name}:", ) idx = menu.show() @@ -1349,10 +1387,12 @@ _PROVIDER_MODELS = { def _model_flow_api_key_provider(config, provider_id, current_model=""): """Generic flow for API-key providers (z.ai, Kimi, MiniMax).""" from hermes_cli.auth import ( - PROVIDER_REGISTRY, _prompt_model_selection, _save_model_choice, - _update_config_for_provider, deactivate_provider, + PROVIDER_REGISTRY, + _prompt_model_selection, + _save_model_choice, + deactivate_provider, ) - from hermes_cli.config import get_env_value, save_env_value, load_config, save_config + from hermes_cli.config import get_env_value, load_config, save_config, save_env_value pconfig = PROVIDER_REGISTRY[provider_id] key_env = pconfig.api_key_env_vars[0] if pconfig.api_key_env_vars else "" @@ -1433,36 +1473,42 @@ def _model_flow_api_key_provider(config, provider_id, current_model=""): def cmd_login(args): """Authenticate Hermes CLI with a provider.""" from hermes_cli.auth import login_command + login_command(args) def cmd_logout(args): """Clear provider authentication.""" from hermes_cli.auth import logout_command + logout_command(args) def cmd_status(args): """Show status of all components.""" from hermes_cli.status import show_status + show_status(args) def cmd_cron(args): """Cron job management.""" from hermes_cli.cron import cron_command + cron_command(args) def cmd_doctor(args): """Check configuration and dependencies.""" from hermes_cli.doctor import run_doctor + run_doctor(args) def cmd_config(args): """Configuration management.""" from hermes_cli.config import config_command + config_command(args) @@ -1470,13 +1516,14 @@ def cmd_version(args): """Show version.""" print(f"Hermes Agent v{__version__}") print(f"Project: {PROJECT_ROOT}") - + # Show Python version print(f"Python: {sys.version.split()[0]}") - + # Check for key dependencies try: import openai + print(f"OpenAI SDK: {openai.__version__}") except ImportError: print("OpenAI SDK: Not installed") @@ -1485,33 +1532,34 @@ def cmd_version(args): def cmd_uninstall(args): """Uninstall Hermes Agent.""" from hermes_cli.uninstall import run_uninstall + run_uninstall(args) def _update_via_zip(args): """Update Hermes Agent by downloading a ZIP archive. - - Used on Windows when git file I/O is broken (antivirus, NTFS filter + + Used on Windows when git file I/O is broken (antivirus, NTFS filter drivers causing 'Invalid argument' errors on file creation). """ import shutil import tempfile import zipfile from urllib.request import urlretrieve - + branch = "main" zip_url = f"https://github.com/NousResearch/hermes-agent/archive/refs/heads/{branch}.zip" - + print("→ Downloading latest version...") try: tmp_dir = tempfile.mkdtemp(prefix="hermes-update-") zip_path = os.path.join(tmp_dir, f"hermes-agent-{branch}.zip") urlretrieve(zip_url, zip_path) - + print("→ Extracting...") - with zipfile.ZipFile(zip_path, 'r') as zf: + with zipfile.ZipFile(zip_path, "r") as zf: zf.extractall(tmp_dir) - + # GitHub ZIPs extract to hermes-agent-/ extracted = os.path.join(tmp_dir, f"hermes-agent-{branch}") if not os.path.isdir(extracted): @@ -1521,9 +1569,9 @@ def _update_via_zip(args): if os.path.isdir(candidate) and d != "__MACOSX": extracted = candidate break - + # Copy updated files over existing installation, preserving venv/node_modules/.git - preserve = {'venv', 'node_modules', '.git', '__pycache__', '.env'} + preserve = {"venv", "node_modules", ".git", "__pycache__", ".env"} update_count = 0 for item in os.listdir(extracted): if item in preserve: @@ -1537,34 +1585,37 @@ def _update_via_zip(args): else: shutil.copy2(src, dst) update_count += 1 - + print(f"✓ Updated {update_count} items from ZIP") - + # Cleanup shutil.rmtree(tmp_dir, ignore_errors=True) - + except Exception as e: print(f"✗ ZIP update failed: {e}") sys.exit(1) - + # Reinstall Python dependencies print("→ Updating Python dependencies...") import subprocess + uv_bin = shutil.which("uv") if uv_bin: subprocess.run( [uv_bin, "pip", "install", "-e", ".", "--quiet"], - cwd=PROJECT_ROOT, check=True, - env={**os.environ, "VIRTUAL_ENV": str(PROJECT_ROOT / "venv")} + cwd=PROJECT_ROOT, + check=True, + env={**os.environ, "VIRTUAL_ENV": str(PROJECT_ROOT / "venv")}, ) else: venv_pip = PROJECT_ROOT / "venv" / ("Scripts" if sys.platform == "win32" else "bin") / "pip" if venv_pip.exists(): subprocess.run([str(venv_pip), "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True) - + # Sync skills try: from tools.skills_sync import sync_skills + print("→ Syncing bundled skills...") result = sync_skills(quiet=True) if result["copied"]: @@ -1579,38 +1630,42 @@ def _update_via_zip(args): print(" ✓ Skills are up to date") except Exception: pass - + print() print("✓ Update complete!") def cmd_update(args): """Update Hermes Agent to the latest version.""" - import subprocess import shutil - + import subprocess + print("⚕ Updating Hermes Agent...") print() - + # Try git-based update first, fall back to ZIP download on Windows # when git file I/O is broken (antivirus, NTFS filter drivers, etc.) use_zip_update = False - git_dir = PROJECT_ROOT / '.git' - + git_dir = PROJECT_ROOT / ".git" + if not git_dir.exists(): if sys.platform == "win32": use_zip_update = True else: print("✗ Not a git repository. Please reinstall:") - print(" curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash") + print( + " curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash" + ) sys.exit(1) - + # On Windows, git can fail with "unable to write loose object file: Invalid argument" # due to filesystem atomicity issues. Set the recommended workaround. if sys.platform == "win32" and git_dir.exists(): subprocess.run( ["git", "-c", "windows.appendAtomically=false", "config", "windows.appendAtomically", "false"], - cwd=PROJECT_ROOT, check=False, capture_output=True + cwd=PROJECT_ROOT, + check=False, + capture_output=True, ) if use_zip_update: @@ -1624,45 +1679,46 @@ def cmd_update(args): git_cmd = ["git"] if sys.platform == "win32": git_cmd = ["git", "-c", "windows.appendAtomically=false"] - + subprocess.run(git_cmd + ["fetch", "origin"], cwd=PROJECT_ROOT, check=True) - + # Get current branch result = subprocess.run( git_cmd + ["rev-parse", "--abbrev-ref", "HEAD"], cwd=PROJECT_ROOT, capture_output=True, text=True, - check=True + check=True, ) branch = result.stdout.strip() - + # Check if there are updates result = subprocess.run( git_cmd + ["rev-list", f"HEAD..origin/{branch}", "--count"], cwd=PROJECT_ROOT, capture_output=True, text=True, - check=True + check=True, ) commit_count = int(result.stdout.strip()) - + if commit_count == 0: print("✓ Already up to date!") return - + print(f"→ Found {commit_count} new commit(s)") print("→ Pulling updates...") subprocess.run(git_cmd + ["pull", "origin", branch], cwd=PROJECT_ROOT, check=True) - + # Reinstall Python dependencies (prefer uv for speed, fall back to pip) print("→ Updating Python dependencies...") uv_bin = shutil.which("uv") if uv_bin: subprocess.run( [uv_bin, "pip", "install", "-e", ".", "--quiet"], - cwd=PROJECT_ROOT, check=True, - env={**os.environ, "VIRTUAL_ENV": str(PROJECT_ROOT / "venv")} + cwd=PROJECT_ROOT, + check=True, + env={**os.environ, "VIRTUAL_ENV": str(PROJECT_ROOT / "venv")}, ) else: venv_pip = PROJECT_ROOT / "venv" / ("Scripts" if sys.platform == "win32" else "bin") / "pip" @@ -1670,20 +1726,22 @@ def cmd_update(args): subprocess.run([str(venv_pip), "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True) else: subprocess.run(["pip", "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True) - + # Check for Node.js deps if (PROJECT_ROOT / "package.json").exists(): import shutil + if shutil.which("npm"): print("→ Updating Node.js dependencies...") subprocess.run(["npm", "install", "--silent"], cwd=PROJECT_ROOT, check=False) - + print() print("✓ Code updated!") - + # Sync bundled skills (copies new, updates changed, respects user deletions) try: from tools.skills_sync import sync_skills + print() print("→ Syncing bundled skills...") result = sync_skills(quiet=True) @@ -1699,36 +1757,38 @@ def cmd_update(args): print(" ✓ Skills are up to date") except Exception as e: logger.debug("Skills sync during update failed: %s", e) - + # Check for config migrations print() print("→ Checking configuration for new options...") - + from hermes_cli.config import ( - get_missing_env_vars, get_missing_config_fields, - check_config_version, migrate_config + check_config_version, + get_missing_config_fields, + get_missing_env_vars, + migrate_config, ) - + missing_env = get_missing_env_vars(required_only=True) missing_config = get_missing_config_fields() current_ver, latest_ver = check_config_version() - + needs_migration = missing_env or missing_config or current_ver < latest_ver - + if needs_migration: print() if missing_env: print(f" ⚠️ {len(missing_env)} new required setting(s) need configuration") if missing_config: print(f" ℹ️ {len(missing_config)} new config option(s) available") - + print() response = input("Would you like to configure them now? [Y/n]: ").strip().lower() - - if response in ('', 'y', 'yes'): + + if response in ("", "y", "yes"): print() results = migrate_config(interactive=True, quiet=False) - + if results["env_added"] or results["config_added"]: print() print("✓ Configuration updated!") @@ -1737,22 +1797,26 @@ def cmd_update(args): print("Skipped. Run 'hermes config migrate' later to configure.") else: print(" ✓ Configuration is up to date") - + print() print("✓ Update complete!") - + # Auto-restart gateway if it's running as a systemd service try: check = subprocess.run( ["systemctl", "--user", "is-active", "hermes-gateway"], - capture_output=True, text=True, timeout=5, + capture_output=True, + text=True, + timeout=5, ) if check.stdout.strip() == "active": print() print("→ Gateway service is running — restarting to pick up changes...") restart = subprocess.run( ["systemctl", "--user", "restart", "hermes-gateway"], - capture_output=True, text=True, timeout=15, + capture_output=True, + text=True, + timeout=15, ) if restart.returncode == 0: print("✓ Gateway restarted.") @@ -1761,11 +1825,11 @@ def cmd_update(args): print(" Try manually: hermes gateway restart") except (FileNotFoundError, subprocess.TimeoutExpired): pass # No systemd (macOS, WSL1, etc.) — skip silently - + print() print("Tip: You can now select a provider and model:") print(" hermes model # Select provider and model") - + except subprocess.CalledProcessError as e: if sys.platform == "win32": print(f"⚠ Git update failed: {e}") @@ -1806,88 +1870,70 @@ Examples: For more help on a command: hermes --help -""" +""", ) - + + parser.add_argument("--version", "-V", action="store_true", help="Show version and exit") parser.add_argument( - "--version", "-V", - action="store_true", - help="Show version and exit" + "--resume", "-r", metavar="SESSION", default=None, help="Resume a previous session by ID or title" ) parser.add_argument( - "--resume", "-r", - metavar="SESSION", - default=None, - help="Resume a previous session by ID or title" - ) - parser.add_argument( - "--continue", "-c", + "--continue", + "-c", dest="continue_last", nargs="?", const=True, default=None, metavar="SESSION_NAME", - help="Resume a session by name, or the most recent if no name given" + help="Resume a session by name, or the most recent if no name given", ) parser.add_argument( - "--worktree", "-w", + "--worktree", + "-w", action="store_true", default=False, - help="Run in an isolated git worktree (for parallel agents)" + help="Run in an isolated git worktree (for parallel agents)", ) - + subparsers = parser.add_subparsers(dest="command", help="Command to run") - + # ========================================================================= # chat command # ========================================================================= chat_parser = subparsers.add_parser( "chat", help="Interactive chat with the agent", - description="Start an interactive chat session with Hermes Agent" - ) - chat_parser.add_argument( - "-q", "--query", - help="Single query (non-interactive mode)" - ) - chat_parser.add_argument( - "-m", "--model", - help="Model to use (e.g., anthropic/claude-sonnet-4)" - ) - chat_parser.add_argument( - "-t", "--toolsets", - help="Comma-separated toolsets to enable" + description="Start an interactive chat session with Hermes Agent", ) + chat_parser.add_argument("-q", "--query", help="Single query (non-interactive mode)") + chat_parser.add_argument("-m", "--model", help="Model to use (e.g., anthropic/claude-sonnet-4)") + chat_parser.add_argument("-t", "--toolsets", help="Comma-separated toolsets to enable") chat_parser.add_argument( "--provider", choices=["auto", "openrouter", "nous", "openai-codex", "zai", "kimi-coding", "minimax", "minimax-cn"], default=None, - help="Inference provider (default: auto)" + help="Inference provider (default: auto)", + ) + chat_parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output") + chat_parser.add_argument( + "--resume", "-r", metavar="SESSION_ID", help="Resume a previous session by ID (shown on exit)" ) chat_parser.add_argument( - "-v", "--verbose", - action="store_true", - help="Verbose output" - ) - chat_parser.add_argument( - "--resume", "-r", - metavar="SESSION_ID", - help="Resume a previous session by ID (shown on exit)" - ) - chat_parser.add_argument( - "--continue", "-c", + "--continue", + "-c", dest="continue_last", nargs="?", const=True, default=None, metavar="SESSION_NAME", - help="Resume a session by name, or the most recent if no name given" + help="Resume a session by name, or the most recent if no name given", ) chat_parser.add_argument( - "--worktree", "-w", + "--worktree", + "-w", action="store_true", default=False, - help="Run in an isolated git worktree (for parallel agents on the same repo)" + help="Run in an isolated git worktree (for parallel agents on the same repo)", ) chat_parser.set_defaults(func=cmd_chat) @@ -1897,7 +1943,7 @@ For more help on a command: model_parser = subparsers.add_parser( "model", help="Select default model and provider", - description="Interactively select your inference provider and default model" + description="Interactively select your inference provider and default model", ) model_parser.set_defaults(func=cmd_model) @@ -1907,33 +1953,34 @@ For more help on a command: gateway_parser = subparsers.add_parser( "gateway", help="Messaging gateway management", - description="Manage the messaging gateway (Telegram, Discord, WhatsApp)" + description="Manage the messaging gateway (Telegram, Discord, WhatsApp)", ) gateway_subparsers = gateway_parser.add_subparsers(dest="gateway_command") - + # gateway run (default) gateway_run = gateway_subparsers.add_parser("run", help="Run gateway in foreground") gateway_run.add_argument("-v", "--verbose", action="store_true") - gateway_run.add_argument("--replace", action="store_true", - help="Replace any existing gateway instance (useful for systemd)") - + gateway_run.add_argument( + "--replace", action="store_true", help="Replace any existing gateway instance (useful for systemd)" + ) + # gateway start gateway_start = gateway_subparsers.add_parser("start", help="Start gateway service") - + # gateway stop gateway_stop = gateway_subparsers.add_parser("stop", help="Stop gateway service") - + # gateway restart gateway_restart = gateway_subparsers.add_parser("restart", help="Restart gateway service") - + # gateway status gateway_status = gateway_subparsers.add_parser("status", help="Show gateway status") gateway_status.add_argument("--deep", action="store_true", help="Deep status check") - + # gateway install gateway_install = gateway_subparsers.add_parser("install", help="Install gateway as service") gateway_install.add_argument("--force", action="store_true", help="Force reinstall") - + # gateway uninstall gateway_uninstall = gateway_subparsers.add_parser("uninstall", help="Uninstall gateway service") @@ -1941,7 +1988,7 @@ For more help on a command: gateway_setup = gateway_subparsers.add_parser("setup", help="Configure messaging platforms") gateway_parser.set_defaults(func=cmd_gateway) - + # ========================================================================= # setup command # ========================================================================= @@ -1949,34 +1996,26 @@ For more help on a command: "setup", help="Interactive setup wizard", description="Configure Hermes Agent with an interactive wizard. " - "Run a specific section: hermes setup model|terminal|gateway|tools|agent" + "Run a specific section: hermes setup model|terminal|gateway|tools|agent", ) setup_parser.add_argument( "section", nargs="?", choices=["model", "terminal", "gateway", "tools", "agent"], default=None, - help="Run a specific setup section instead of the full wizard" + help="Run a specific setup section instead of the full wizard", ) setup_parser.add_argument( - "--non-interactive", - action="store_true", - help="Non-interactive mode (use defaults/env vars)" - ) - setup_parser.add_argument( - "--reset", - action="store_true", - help="Reset configuration to defaults" + "--non-interactive", action="store_true", help="Non-interactive mode (use defaults/env vars)" ) + setup_parser.add_argument("--reset", action="store_true", help="Reset configuration to defaults") setup_parser.set_defaults(func=cmd_setup) # ========================================================================= # whatsapp command # ========================================================================= whatsapp_parser = subparsers.add_parser( - "whatsapp", - help="Set up WhatsApp integration", - description="Configure WhatsApp and pair via QR code" + "whatsapp", help="Set up WhatsApp integration", description="Configure WhatsApp and pair via QR code" ) whatsapp_parser.set_defaults(func=cmd_whatsapp) @@ -1986,52 +2025,26 @@ For more help on a command: login_parser = subparsers.add_parser( "login", help="Authenticate with an inference provider", - description="Run OAuth device authorization flow for Hermes CLI" + description="Run OAuth device authorization flow for Hermes CLI", ) login_parser.add_argument( "--provider", choices=["nous", "openai-codex"], default=None, - help="Provider to authenticate with (default: nous)" + help="Provider to authenticate with (default: nous)", + ) + login_parser.add_argument("--portal-url", help="Portal base URL (default: production portal)") + login_parser.add_argument("--inference-url", help="Inference API base URL (default: production inference API)") + login_parser.add_argument("--client-id", default=None, help="OAuth client id to use (default: hermes-cli)") + login_parser.add_argument("--scope", default=None, help="OAuth scope to request") + login_parser.add_argument( + "--no-browser", action="store_true", help="Do not attempt to open the browser automatically" ) login_parser.add_argument( - "--portal-url", - help="Portal base URL (default: production portal)" - ) - login_parser.add_argument( - "--inference-url", - help="Inference API base URL (default: production inference API)" - ) - login_parser.add_argument( - "--client-id", - default=None, - help="OAuth client id to use (default: hermes-cli)" - ) - login_parser.add_argument( - "--scope", - default=None, - help="OAuth scope to request" - ) - login_parser.add_argument( - "--no-browser", - action="store_true", - help="Do not attempt to open the browser automatically" - ) - login_parser.add_argument( - "--timeout", - type=float, - default=15.0, - help="HTTP request timeout in seconds (default: 15)" - ) - login_parser.add_argument( - "--ca-bundle", - help="Path to CA bundle PEM file for TLS verification" - ) - login_parser.add_argument( - "--insecure", - action="store_true", - help="Disable TLS verification (testing only)" + "--timeout", type=float, default=15.0, help="HTTP request timeout in seconds (default: 15)" ) + login_parser.add_argument("--ca-bundle", help="Path to CA bundle PEM file for TLS verification") + login_parser.add_argument("--insecure", action="store_true", help="Disable TLS verification (testing only)") login_parser.set_defaults(func=cmd_login) # ========================================================================= @@ -2040,13 +2053,13 @@ For more help on a command: logout_parser = subparsers.add_parser( "logout", help="Clear authentication for an inference provider", - description="Remove stored credentials and reset provider config" + description="Remove stored credentials and reset provider config", ) logout_parser.add_argument( "--provider", choices=["nous", "openai-codex"], default=None, - help="Provider to log out from (default: active provider)" + help="Provider to log out from (default: active provider)", ) logout_parser.set_defaults(func=cmd_logout) @@ -2054,101 +2067,79 @@ For more help on a command: # status command # ========================================================================= status_parser = subparsers.add_parser( - "status", - help="Show status of all components", - description="Display status of Hermes Agent components" - ) - status_parser.add_argument( - "--all", - action="store_true", - help="Show all details (redacted for sharing)" - ) - status_parser.add_argument( - "--deep", - action="store_true", - help="Run deep checks (may take longer)" + "status", help="Show status of all components", description="Display status of Hermes Agent components" ) + status_parser.add_argument("--all", action="store_true", help="Show all details (redacted for sharing)") + status_parser.add_argument("--deep", action="store_true", help="Run deep checks (may take longer)") status_parser.set_defaults(func=cmd_status) - + # ========================================================================= # cron command # ========================================================================= - cron_parser = subparsers.add_parser( - "cron", - help="Cron job management", - description="Manage scheduled tasks" - ) + cron_parser = subparsers.add_parser("cron", help="Cron job management", description="Manage scheduled tasks") cron_subparsers = cron_parser.add_subparsers(dest="cron_command") - + # cron list cron_list = cron_subparsers.add_parser("list", help="List scheduled jobs") cron_list.add_argument("--all", action="store_true", help="Include disabled jobs") - + # cron status cron_subparsers.add_parser("status", help="Check if cron scheduler is running") - + # cron tick (mostly for debugging) cron_subparsers.add_parser("tick", help="Run due jobs once and exit") - + cron_parser.set_defaults(func=cmd_cron) - + # ========================================================================= # doctor command # ========================================================================= doctor_parser = subparsers.add_parser( - "doctor", - help="Check configuration and dependencies", - description="Diagnose issues with Hermes Agent setup" - ) - doctor_parser.add_argument( - "--fix", - action="store_true", - help="Attempt to fix issues automatically" + "doctor", help="Check configuration and dependencies", description="Diagnose issues with Hermes Agent setup" ) + doctor_parser.add_argument("--fix", action="store_true", help="Attempt to fix issues automatically") doctor_parser.set_defaults(func=cmd_doctor) - + # ========================================================================= # config command # ========================================================================= config_parser = subparsers.add_parser( - "config", - help="View and edit configuration", - description="Manage Hermes Agent configuration" + "config", help="View and edit configuration", description="Manage Hermes Agent configuration" ) config_subparsers = config_parser.add_subparsers(dest="config_command") - + # config show (default) config_show = config_subparsers.add_parser("show", help="Show current configuration") - + # config edit config_edit = config_subparsers.add_parser("edit", help="Open config file in editor") - + # config set config_set = config_subparsers.add_parser("set", help="Set a configuration value") config_set.add_argument("key", nargs="?", help="Configuration key (e.g., model, terminal.backend)") config_set.add_argument("value", nargs="?", help="Value to set") - + # config path config_path = config_subparsers.add_parser("path", help="Print config file path") - + # config env-path config_env = config_subparsers.add_parser("env-path", help="Print .env file path") - + # config check config_check = config_subparsers.add_parser("check", help="Check for missing/outdated config") - + # config migrate config_migrate = config_subparsers.add_parser("migrate", help="Update config with new options") - + config_parser.set_defaults(func=cmd_config) - + # ========================================================================= # pairing command # ========================================================================= pairing_parser = subparsers.add_parser( "pairing", help="Manage DM pairing codes for user authorization", - description="Approve or revoke user access via pairing codes" + description="Approve or revoke user access via pairing codes", ) pairing_sub = pairing_parser.add_subparsers(dest="pairing_action") @@ -2166,6 +2157,7 @@ For more help on a command: def cmd_pairing(args): from hermes_cli.pairing import pairing_command + pairing_command(args) pairing_parser.set_defaults(func=cmd_pairing) @@ -2176,16 +2168,19 @@ For more help on a command: skills_parser = subparsers.add_parser( "skills", help="Skills Hub — search, install, and manage skills from online registries", - description="Search, install, inspect, audit, and manage skills from GitHub, ClawHub, and other registries." + description="Search, install, inspect, audit, and manage skills from GitHub, ClawHub, and other registries.", ) skills_subparsers = skills_parser.add_subparsers(dest="skills_action") skills_browse = skills_subparsers.add_parser("browse", help="Browse all available skills (paginated)") skills_browse.add_argument("--page", type=int, default=1, help="Page number (default: 1)") skills_browse.add_argument("--size", type=int, default=20, help="Results per page (default: 20)") - skills_browse.add_argument("--source", default="all", - choices=["all", "official", "github", "clawhub", "lobehub"], - help="Filter by source (default: all)") + skills_browse.add_argument( + "--source", + default="all", + choices=["all", "official", "github", "clawhub", "lobehub"], + help="Filter by source (default: all)", + ) skills_search = skills_subparsers.add_parser("search", help="Search skill registries") skills_search.add_argument("query", help="Search query") @@ -2232,6 +2227,7 @@ For more help on a command: def cmd_skills(args): from hermes_cli.skills_hub import skills_command + skills_command(args) skills_parser.set_defaults(func=cmd_skills) @@ -2242,11 +2238,12 @@ For more help on a command: tools_parser = subparsers.add_parser( "tools", help="Configure which tools are enabled per platform", - description="Interactive tool configuration — enable/disable tools for CLI, Telegram, Discord, etc." + description="Interactive tool configuration — enable/disable tools for CLI, Telegram, Discord, etc.", ) def cmd_tools(args): from hermes_cli.tools_config import tools_command + tools_command(args) tools_parser.set_defaults(func=cmd_tools) @@ -2257,7 +2254,7 @@ For more help on a command: sessions_parser = subparsers.add_parser( "sessions", help="Manage session history (list, rename, export, prune, delete)", - description="View and manage the SQLite session store" + description="View and manage the SQLite session store", ) sessions_subparsers = sessions_parser.add_subparsers(dest="sessions_action") @@ -2275,7 +2272,9 @@ For more help on a command: sessions_delete.add_argument("--yes", "-y", action="store_true", help="Skip confirmation") sessions_prune = sessions_subparsers.add_parser("prune", help="Delete old sessions") - sessions_prune.add_argument("--older-than", type=int, default=90, help="Delete sessions older than N days (default: 90)") + sessions_prune.add_argument( + "--older-than", type=int, default=90, help="Delete sessions older than N days (default: 90)" + ) sessions_prune.add_argument("--source", help="Only prune sessions from this source") sessions_prune.add_argument("--yes", "-y", action="store_true", help="Skip confirmation") @@ -2294,8 +2293,10 @@ For more help on a command: def cmd_sessions(args): import json as _json + try: from hermes_state import SessionDB + db = SessionDB() except Exception as e: print(f"Error: Could not open session database: {e}") @@ -2308,8 +2309,8 @@ For more help on a command: if not sessions: print("No sessions found.") return - from datetime import datetime import time as _time + from datetime import datetime def _relative_time(ts): """Format a timestamp as relative time (e.g., '2h ago', 'yesterday').""" @@ -2415,6 +2416,7 @@ For more help on a command: # Launch hermes --resume by replacing the current process print(f"Resuming session: {selected_id}") import shutil + hermes_bin = shutil.which("hermes") if hermes_bin: os.execvp(hermes_bin, ["hermes", "--resume", selected_id]) @@ -2453,15 +2455,15 @@ For more help on a command: insights_parser = subparsers.add_parser( "insights", help="Show usage insights and analytics", - description="Analyze session history to show token usage, costs, tool patterns, and activity trends" + description="Analyze session history to show token usage, costs, tool patterns, and activity trends", ) insights_parser.add_argument("--days", type=int, default=30, help="Number of days to analyze (default: 30)") insights_parser.add_argument("--source", help="Filter by platform (cli, telegram, discord, etc.)") def cmd_insights(args): try: - from hermes_state import SessionDB from agent.insights import InsightsEngine + from hermes_state import SessionDB db = SessionDB() engine = InsightsEngine(db) @@ -2476,52 +2478,43 @@ For more help on a command: # ========================================================================= # version command # ========================================================================= - version_parser = subparsers.add_parser( - "version", - help="Show version information" - ) + version_parser = subparsers.add_parser("version", help="Show version information") version_parser.set_defaults(func=cmd_version) - + # ========================================================================= # update command # ========================================================================= update_parser = subparsers.add_parser( "update", help="Update Hermes Agent to the latest version", - description="Pull the latest changes from git and reinstall dependencies" + description="Pull the latest changes from git and reinstall dependencies", ) update_parser.set_defaults(func=cmd_update) - + # ========================================================================= # uninstall command # ========================================================================= uninstall_parser = subparsers.add_parser( "uninstall", help="Uninstall Hermes Agent", - description="Remove Hermes Agent from your system. Can keep configs/data for reinstall." + description="Remove Hermes Agent from your system. Can keep configs/data for reinstall.", ) uninstall_parser.add_argument( - "--full", - action="store_true", - help="Full uninstall - remove everything including configs and data" - ) - uninstall_parser.add_argument( - "--yes", "-y", - action="store_true", - help="Skip confirmation prompts" + "--full", action="store_true", help="Full uninstall - remove everything including configs and data" ) + uninstall_parser.add_argument("--yes", "-y", action="store_true", help="Skip confirmation prompts") uninstall_parser.set_defaults(func=cmd_uninstall) - + # ========================================================================= # Parse and execute # ========================================================================= args = parser.parse_args() - + # Handle --version flag if args.version: cmd_version(args) return - + # Handle top-level --resume / --continue as shortcut to chat if (args.resume or args.continue_last) and args.command is None: args.command = "chat" @@ -2534,7 +2527,7 @@ For more help on a command: args.worktree = False cmd_chat(args) return - + # Default to chat if no command specified if args.command is None: args.query = None @@ -2548,9 +2541,9 @@ For more help on a command: args.worktree = False cmd_chat(args) return - + # Execute the command - if hasattr(args, 'func'): + if hasattr(args, "func"): args.func(args) else: parser.print_help() diff --git a/hermes_cli/models.py b/hermes_cli/models.py index 1fdde0900c..3d4076319d 100644 --- a/hermes_cli/models.py +++ b/hermes_cli/models.py @@ -8,26 +8,26 @@ Add, remove, or reorder entries here — both `hermes setup` and from __future__ import annotations import json -import urllib.request import urllib.error +import urllib.request from difflib import get_close_matches -from typing import Any, Optional +from typing import Any # (model_id, display description shown in menus) OPENROUTER_MODELS: list[tuple[str, str]] = [ - ("anthropic/claude-opus-4.6", "recommended"), - ("anthropic/claude-sonnet-4.5", ""), - ("openai/gpt-5.4-pro", ""), - ("openai/gpt-5.4", ""), - ("openai/gpt-5.3-codex", ""), - ("google/gemini-3-pro-preview", ""), - ("google/gemini-3-flash-preview", ""), - ("qwen/qwen3.5-plus-02-15", ""), - ("qwen/qwen3.5-35b-a3b", ""), - ("stepfun/step-3.5-flash", ""), - ("z-ai/glm-5", ""), - ("moonshotai/kimi-k2.5", ""), - ("minimax/minimax-m2.5", ""), + ("anthropic/claude-opus-4.6", "recommended"), + ("anthropic/claude-sonnet-4.5", ""), + ("openai/gpt-5.4-pro", ""), + ("openai/gpt-5.4", ""), + ("openai/gpt-5.3-codex", ""), + ("google/gemini-3-pro-preview", ""), + ("google/gemini-3-flash-preview", ""), + ("qwen/qwen3.5-plus-02-15", ""), + ("qwen/qwen3.5-35b-a3b", ""), + ("stepfun/step-3.5-flash", ""), + ("z-ai/glm-5", ""), + ("moonshotai/kimi-k2.5", ""), + ("minimax/minimax-m2.5", ""), ] _PROVIDER_MODELS: dict[str, list[str]] = { @@ -93,9 +93,7 @@ def menu_labels() -> list[str]: # All provider IDs and aliases that are valid for the provider:model syntax. _KNOWN_PROVIDER_NAMES: set[str] = ( - set(_PROVIDER_LABELS.keys()) - | set(_PROVIDER_ALIASES.keys()) - | {"openrouter", "custom"} + set(_PROVIDER_LABELS.keys()) | set(_PROVIDER_ALIASES.keys()) | {"openrouter", "custom"} ) @@ -107,8 +105,13 @@ def list_available_providers() -> list[dict[str, str]]: """ # Canonical providers in display order _PROVIDER_ORDER = [ - "openrouter", "nous", "openai-codex", - "zai", "kimi-coding", "minimax", "minimax-cn", + "openrouter", + "nous", + "openai-codex", + "zai", + "kimi-coding", + "minimax", + "minimax-cn", ] # Build reverse alias map aliases_for: dict[str, list[str]] = {} @@ -123,16 +126,19 @@ def list_available_providers() -> list[dict[str, str]]: has_creds = False try: from hermes_cli.runtime_provider import resolve_runtime_provider + runtime = resolve_runtime_provider(requested=pid) has_creds = bool(runtime.get("api_key")) except Exception: pass - result.append({ - "id": pid, - "label": label, - "aliases": alias_list, - "authenticated": has_creds, - }) + result.append( + { + "id": pid, + "label": label, + "aliases": alias_list, + "authenticated": has_creds, + } + ) return result @@ -157,13 +163,13 @@ def parse_model_input(raw: str, current_provider: str) -> tuple[str, str]: colon = stripped.find(":") if colon > 0: provider_part = stripped[:colon].strip().lower() - model_part = stripped[colon + 1:].strip() + model_part = stripped[colon + 1 :].strip() if provider_part and model_part and provider_part in _KNOWN_PROVIDER_NAMES: return (normalize_provider(provider_part), model_part) return (current_provider, stripped) -def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str]]: +def curated_models_for_provider(provider: str | None) -> list[tuple[str, str]]: """Return ``(model_id, description)`` tuples for a provider's curated list.""" normalized = normalize_provider(provider) if normalized == "openrouter": @@ -172,7 +178,7 @@ def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str] return [(m, "") for m in models] -def normalize_provider(provider: Optional[str]) -> str: +def normalize_provider(provider: str | None) -> str: """Normalize provider aliases to Hermes' canonical provider ids. Note: ``"auto"`` passes through unchanged — use @@ -183,7 +189,7 @@ def normalize_provider(provider: Optional[str]) -> str: return _PROVIDER_ALIASES.get(normalized, normalized) -def provider_model_ids(provider: Optional[str]) -> list[str]: +def provider_model_ids(provider: str | None) -> list[str]: """Return the best known model catalog for a provider.""" normalized = normalize_provider(provider) if normalized == "openrouter": @@ -196,10 +202,10 @@ def provider_model_ids(provider: Optional[str]) -> list[str]: def fetch_api_models( - api_key: Optional[str], - base_url: Optional[str], + api_key: str | None, + base_url: str | None, timeout: float = 5.0, -) -> Optional[list[str]]: +) -> list[str] | None: """Fetch the list of available model IDs from the provider's ``/models`` endpoint. Returns a list of model ID strings, or ``None`` if the endpoint could not @@ -225,10 +231,10 @@ def fetch_api_models( def validate_requested_model( model_name: str, - provider: Optional[str], + provider: str | None, *, - api_key: Optional[str] = None, - base_url: Optional[str] = None, + api_key: str | None = None, + base_url: str | None = None, ) -> dict[str, Any]: """ Validate a ``/model`` value for the active provider. @@ -286,10 +292,7 @@ def validate_requested_model( "accepted": False, "persist": False, "recognized": False, - "message": ( - f"Error: `{requested}` is not a valid model for this provider." - f"{suggestion_text}" - ), + "message": (f"Error: `{requested}` is not a valid model for this provider.{suggestion_text}"), } # api_models is None — couldn't reach API, fall back to catalog check diff --git a/hermes_cli/pairing.py b/hermes_cli/pairing.py index ecd9f61fcf..38fa0f36cc 100644 --- a/hermes_cli/pairing.py +++ b/hermes_cli/pairing.py @@ -8,6 +8,7 @@ Usage: hermes pairing clear-pending # Clear all expired/pending codes """ + def pairing_command(args): """Handle hermes pairing subcommands.""" from gateway.pairing import PairingStore @@ -72,10 +73,10 @@ def _cmd_approve(store, platform: str, code: str): name = result.get("user_name", "") display = f"{name} ({uid})" if name else uid print(f"\n Approved! User {display} on {platform} can now use the bot~") - print(f" They'll be recognized automatically on their next message.\n") + print(" They'll be recognized automatically on their next message.\n") else: print(f"\n Code '{code}' not found or expired for platform '{platform}'.") - print(f" Run 'hermes pairing list' to see pending codes.\n") + print(" Run 'hermes pairing list' to see pending codes.\n") def _cmd_revoke(store, platform: str, user_id: str): diff --git a/hermes_cli/runtime_provider.py b/hermes_cli/runtime_provider.py index bf86fa88b6..2b53f06c54 100644 --- a/hermes_cli/runtime_provider.py +++ b/hermes_cli/runtime_provider.py @@ -3,22 +3,22 @@ from __future__ import annotations import os -from typing import Any, Dict, Optional +from typing import Any from hermes_cli.auth import ( - AuthError, PROVIDER_REGISTRY, + AuthError, format_auth_error, - resolve_provider, - resolve_nous_runtime_credentials, - resolve_codex_runtime_credentials, resolve_api_key_provider_credentials, + resolve_codex_runtime_credentials, + resolve_nous_runtime_credentials, + resolve_provider, ) from hermes_cli.config import load_config from hermes_constants import OPENROUTER_BASE_URL -def _get_model_config() -> Dict[str, Any]: +def _get_model_config() -> dict[str, Any]: config = load_config() model_cfg = config.get("model") if isinstance(model_cfg, dict): @@ -28,7 +28,7 @@ def _get_model_config() -> Dict[str, Any]: return {} -def resolve_requested_provider(requested: Optional[str] = None) -> str: +def resolve_requested_provider(requested: str | None = None) -> str: """Resolve provider request from explicit arg, env, then config.""" if requested and requested.strip(): return requested.strip().lower() @@ -48,9 +48,9 @@ def resolve_requested_provider(requested: Optional[str] = None) -> str: def _resolve_openrouter_runtime( *, requested_provider: str, - explicit_api_key: Optional[str] = None, - explicit_base_url: Optional[str] = None, -) -> Dict[str, Any]: + explicit_api_key: str | None = None, + explicit_base_url: str | None = None, +) -> dict[str, Any]: model_cfg = _get_model_config() cfg_base_url = model_cfg.get("base_url") if isinstance(model_cfg.get("base_url"), str) else "" cfg_provider = model_cfg.get("provider") if isinstance(model_cfg.get("provider"), str) else "" @@ -81,19 +81,9 @@ def _resolve_openrouter_runtime( # provider (issues #420, #560). _is_openrouter_url = "openrouter.ai" in base_url if _is_openrouter_url: - api_key = ( - explicit_api_key - or os.getenv("OPENROUTER_API_KEY") - or os.getenv("OPENAI_API_KEY") - or "" - ) + api_key = explicit_api_key or os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") or "" else: - api_key = ( - explicit_api_key - or os.getenv("OPENAI_API_KEY") - or os.getenv("OPENROUTER_API_KEY") - or "" - ) + api_key = explicit_api_key or os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY") or "" source = "explicit" if (explicit_api_key or explicit_base_url) else "env/config" @@ -108,10 +98,10 @@ def _resolve_openrouter_runtime( def resolve_runtime_provider( *, - requested: Optional[str] = None, - explicit_api_key: Optional[str] = None, - explicit_base_url: Optional[str] = None, -) -> Dict[str, Any]: + requested: str | None = None, + explicit_api_key: str | None = None, + explicit_base_url: str | None = None, +) -> dict[str, Any]: """Resolve runtime provider credentials for agent execution.""" requested_provider = resolve_requested_provider(requested) diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index 5880b7ef35..b953c58331 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -15,84 +15,97 @@ import logging import os import sys from pathlib import Path -from typing import Optional, Dict, Any logger = logging.getLogger(__name__) PROJECT_ROOT = Path(__file__).parent.parent.resolve() # Import config helpers +from hermes_cli.colors import Colors, color from hermes_cli.config import ( - get_hermes_home, get_config_path, get_env_path, - load_config, save_config, save_env_value, get_env_value, - ensure_hermes_home, DEFAULT_CONFIG + ensure_hermes_home, + get_config_path, + get_env_path, + get_env_value, + get_hermes_home, + load_config, + save_config, + save_env_value, ) -from hermes_cli.colors import Colors, color def print_header(title: str): """Print a section header.""" print() print(color(f"◆ {title}", Colors.CYAN, Colors.BOLD)) + def print_info(text: str): """Print info text.""" print(color(f" {text}", Colors.DIM)) + def print_success(text: str): """Print success message.""" print(color(f"✓ {text}", Colors.GREEN)) + def print_warning(text: str): """Print warning message.""" print(color(f"⚠ {text}", Colors.YELLOW)) + def print_error(text: str): """Print error message.""" print(color(f"✗ {text}", Colors.RED)) + def prompt(question: str, default: str = None, password: bool = False) -> str: """Prompt for input with optional default.""" if default: display = f"{question} [{default}]: " else: display = f"{question}: " - + try: if password: import getpass + value = getpass.getpass(color(display, Colors.YELLOW)) else: value = input(color(display, Colors.YELLOW)) - + return value.strip() or default or "" except (KeyboardInterrupt, EOFError): print() sys.exit(1) + def prompt_choice(question: str, choices: list, default: int = 0) -> int: """Prompt for a choice from a list with arrow key navigation. - + Escape keeps the current default (skips the question). Ctrl+C exits the wizard. """ print(color(question, Colors.YELLOW)) - + # Try to use interactive menu if available try: - from simple_term_menu import TerminalMenu import re - + + from simple_term_menu import TerminalMenu + # Strip emoji characters — simple_term_menu miscalculates visual # width of emojis, causing duplicated/garbled lines on redraw. _emoji_re = re.compile( "[\U0001f300-\U0001f9ff\U00002600-\U000027bf\U0000fe00-\U0000fe0f" - "\U0001fa00-\U0001fa6f\U0001fa70-\U0001faff\u200d]+", flags=re.UNICODE + "\U0001fa00-\U0001fa6f\U0001fa70-\U0001faff\u200d]+", + flags=re.UNICODE, ) menu_choices = [f" {_emoji_re.sub('', choice).strip()}" for choice in choices] - + print_info(" ↑/↓ Navigate Enter Select Esc Skip Ctrl+C Exit") - + terminal_menu = TerminalMenu( menu_choices, cursor_index=default, @@ -102,15 +115,15 @@ def prompt_choice(question: str, choices: list, default: int = 0) -> int: cycle_cursor=True, clear_screen=False, ) - + idx = terminal_menu.show() if idx is None: # User pressed Escape — keep current value - print_info(f" Skipped (keeping current)") + print_info(" Skipped (keeping current)") print() return default print() # Add newline after selection return idx - + except (ImportError, NotImplementedError): pass except Exception as e: @@ -141,22 +154,23 @@ def prompt_choice(question: str, choices: list, default: int = 0) -> int: print() sys.exit(1) + def prompt_yes_no(question: str, default: bool = True) -> bool: """Prompt for yes/no. Ctrl+C exits, empty input returns default.""" default_str = "Y/n" if default else "y/N" - + while True: try: value = input(color(f"{question} [{default_str}]: ", Colors.YELLOW)).strip().lower() except (KeyboardInterrupt, EOFError): print() sys.exit(1) - + if not value: return default - if value in ('y', 'yes'): + if value in ("y", "yes"): return True - if value in ('n', 'no'): + if value in ("n", "no"): return False print_error("Please enter 'y' or 'n'") @@ -164,40 +178,42 @@ def prompt_yes_no(question: str, default: bool = True) -> bool: def prompt_checklist(title: str, items: list, pre_selected: list = None) -> list: """ Display a multi-select checklist and return the indices of selected items. - + Each item in `items` is a display string. `pre_selected` is a list of indices that should be checked by default. A "Continue →" option is appended at the end — the user toggles items with Space and confirms with Enter on "Continue →". - + Falls back to a numbered toggle interface when simple_term_menu is unavailable. - + Returns: List of selected indices (not including the Continue option). """ if pre_selected is None: pre_selected = [] - + print(color(title, Colors.YELLOW)) print_info(" SPACE Toggle ENTER Confirm ESC Skip Ctrl+C Exit") print() - + try: - from simple_term_menu import TerminalMenu import re - + + from simple_term_menu import TerminalMenu + # Strip emoji characters from menu labels — simple_term_menu miscalculates # visual width of emojis on macOS, causing duplicated/garbled lines. _emoji_re = re.compile( "[\U0001f300-\U0001f9ff\U00002600-\U000027bf\U0000fe00-\U0000fe0f" - "\U0001fa00-\U0001fa6f\U0001fa70-\U0001faff\u200d]+", flags=re.UNICODE + "\U0001fa00-\U0001fa6f\U0001fa70-\U0001faff\u200d]+", + flags=re.UNICODE, ) menu_items = [f" {_emoji_re.sub('', item).strip()}" for item in items] - + # Map pre-selected indices to the actual menu entry strings preselected = [menu_items[i] for i in pre_selected if i < len(menu_items)] - + terminal_menu = TerminalMenu( menu_items, multi_select=True, @@ -212,26 +228,26 @@ def prompt_checklist(title: str, items: list, pre_selected: list = None) -> list cycle_cursor=True, clear_screen=False, ) - + terminal_menu.show() - + if terminal_menu.chosen_menu_entries is None: print_info(" Skipped (keeping current)") return list(pre_selected) - + selected = list(terminal_menu.chosen_menu_indices or []) return selected - + except (ImportError, NotImplementedError): # Fallback: numbered toggle interface (simple_term_menu doesn't support Windows) selected = set(pre_selected) - + while True: for i, item in enumerate(items): marker = color("[✓]", Colors.GREEN) if i in selected else "[ ]" print(f" {marker} {i + 1}. {item}") print() - + try: value = input(color(" Toggle # (or Enter to confirm): ", Colors.DIM)).strip() if not value: @@ -249,10 +265,10 @@ def prompt_checklist(title: str, items: list, pre_selected: list = None) -> list except (KeyboardInterrupt, EOFError): print() return [] - + # Clear and redraw (simple approach) print() - + return sorted(selected) @@ -279,9 +295,9 @@ def _prompt_api_key(var: dict): if value: save_env_value(var["name"], value) - print_success(f" ✓ Saved") + print_success(" ✓ Saved") else: - print_warning(f" Skipped (configure later with 'hermes setup')") + print_warning(" Skipped (configure later with 'hermes setup')") def _print_setup_summary(config: dict, hermes_home): @@ -289,103 +305,107 @@ def _print_setup_summary(config: dict, hermes_home): # Tool availability summary print() print_header("Tool Availability Summary") - + tool_status = [] - + # OpenRouter (required for vision, moa) - if get_env_value('OPENROUTER_API_KEY'): + if get_env_value("OPENROUTER_API_KEY"): tool_status.append(("Vision (image analysis)", True, None)) tool_status.append(("Mixture of Agents", True, None)) else: tool_status.append(("Vision (image analysis)", False, "OPENROUTER_API_KEY")) tool_status.append(("Mixture of Agents", False, "OPENROUTER_API_KEY")) - + # Firecrawl (web tools) - if get_env_value('FIRECRAWL_API_KEY') or get_env_value('FIRECRAWL_API_URL'): + if get_env_value("FIRECRAWL_API_KEY") or get_env_value("FIRECRAWL_API_URL"): tool_status.append(("Web Search & Extract", True, None)) else: tool_status.append(("Web Search & Extract", False, "FIRECRAWL_API_KEY")) - + # Browser tools (local Chromium or Browserbase cloud) import shutil - _ab_found = shutil.which("agent-browser") or (Path(__file__).parent.parent / "node_modules" / ".bin" / "agent-browser").exists() - if get_env_value('BROWSERBASE_API_KEY'): + + _ab_found = ( + shutil.which("agent-browser") + or (Path(__file__).parent.parent / "node_modules" / ".bin" / "agent-browser").exists() + ) + if get_env_value("BROWSERBASE_API_KEY"): tool_status.append(("Browser Automation (Browserbase)", True, None)) elif _ab_found: tool_status.append(("Browser Automation (local)", True, None)) else: tool_status.append(("Browser Automation", False, "npm install -g agent-browser")) - + # FAL (image generation) - if get_env_value('FAL_KEY'): + if get_env_value("FAL_KEY"): tool_status.append(("Image Generation", True, None)) else: tool_status.append(("Image Generation", False, "FAL_KEY")) - + # TTS — show configured provider - tts_provider = config.get('tts', {}).get('provider', 'edge') - if tts_provider == 'elevenlabs' and get_env_value('ELEVENLABS_API_KEY'): + tts_provider = config.get("tts", {}).get("provider", "edge") + if tts_provider == "elevenlabs" and get_env_value("ELEVENLABS_API_KEY"): tool_status.append(("Text-to-Speech (ElevenLabs)", True, None)) - elif tts_provider == 'openai' and get_env_value('VOICE_TOOLS_OPENAI_KEY'): + elif tts_provider == "openai" and get_env_value("VOICE_TOOLS_OPENAI_KEY"): tool_status.append(("Text-to-Speech (OpenAI)", True, None)) else: tool_status.append(("Text-to-Speech (Edge TTS)", True, None)) - + # Tinker + WandB (RL training) - if get_env_value('TINKER_API_KEY') and get_env_value('WANDB_API_KEY'): + if get_env_value("TINKER_API_KEY") and get_env_value("WANDB_API_KEY"): tool_status.append(("RL Training (Tinker)", True, None)) - elif get_env_value('TINKER_API_KEY'): + elif get_env_value("TINKER_API_KEY"): tool_status.append(("RL Training (Tinker)", False, "WANDB_API_KEY")) else: tool_status.append(("RL Training (Tinker)", False, "TINKER_API_KEY")) - + # Home Assistant - if get_env_value('HASS_TOKEN'): + if get_env_value("HASS_TOKEN"): tool_status.append(("Smart Home (Home Assistant)", True, None)) - + # Skills Hub - if get_env_value('GITHUB_TOKEN'): + if get_env_value("GITHUB_TOKEN"): tool_status.append(("Skills Hub (GitHub)", True, None)) else: tool_status.append(("Skills Hub (GitHub)", False, "GITHUB_TOKEN")) - + # Terminal (always available if system deps met) tool_status.append(("Terminal/Commands", True, None)) - + # Task planning (always available, in-memory) tool_status.append(("Task Planning (todo)", True, None)) - + # Skills (always available -- bundled skills + user-created skills) tool_status.append(("Skills (view, create, edit)", True, None)) - + # Print status available_count = sum(1 for _, avail, _ in tool_status if avail) total_count = len(tool_status) - + print_info(f"{available_count}/{total_count} tool categories available:") print() - + for name, available, missing_var in tool_status: if available: print(f" {color('✓', Colors.GREEN)} {name}") else: print(f" {color('✗', Colors.RED)} {name} {color(f'(missing {missing_var})', Colors.DIM)}") - + print() - + disabled_tools = [(name, var) for name, avail, var in tool_status if not avail] if disabled_tools: print_warning("Some tools are disabled. Run 'hermes setup tools' to configure them,") print_warning("or edit ~/.hermes/.env directly to add the missing API keys.") print() - + # Done banner print() print(color("┌─────────────────────────────────────────────────────────┐", Colors.GREEN)) print(color("│ ✓ Setup Complete! │", Colors.GREEN)) print(color("└─────────────────────────────────────────────────────────┘", Colors.GREEN)) print() - + # Show file locations prominently print(color("📁 All your files are in ~/.hermes/:", Colors.CYAN, Colors.BOLD)) print() @@ -393,7 +413,7 @@ def _print_setup_summary(config: dict, hermes_home): print(f" {color('API Keys:', Colors.YELLOW)} {get_env_path()}") print(f" {color('Data:', Colors.YELLOW)} {hermes_home}/cron/, sessions/, logs/") print() - + print(color("─" * 60, Colors.DIM)) print() print(color("📝 To edit your configuration:", Colors.CYAN, Colors.BOLD)) @@ -407,13 +427,13 @@ def _print_setup_summary(config: dict, hermes_home): print(f" {color('hermes config', Colors.GREEN)} View current settings") print(f" {color('hermes config edit', Colors.GREEN)} Open config in your editor") print(f" {color('hermes config set KEY VALUE', Colors.GREEN)}") - print(f" Set a specific value") + print(" Set a specific value") print() - print(f" Or edit the files directly:") + print(" Or edit the files directly:") print(f" {color(f'nano {get_config_path()}', Colors.DIM)}") print(f" {color(f'nano {get_env_path()}', Colors.DIM)}") print() - + print(color("─" * 60, Colors.DIM)) print() print(color("🚀 Ready to go!", Colors.CYAN, Colors.BOLD)) @@ -426,45 +446,44 @@ def _print_setup_summary(config: dict, hermes_home): def _prompt_container_resources(config: dict): """Prompt for container resource settings (Docker, Singularity, Modal, Daytona).""" - terminal = config.setdefault('terminal', {}) + terminal = config.setdefault("terminal", {}) print() print_info("Container Resource Settings:") # Persistence - current_persist = terminal.get('container_persistent', True) + current_persist = terminal.get("container_persistent", True) persist_label = "yes" if current_persist else "no" print_info(" Persistent filesystem keeps files between sessions.") print_info(" Set to 'no' for ephemeral sandboxes that reset each time.") - persist_str = prompt(f" Persist filesystem across sessions? (yes/no)", persist_label) - terminal['container_persistent'] = persist_str.lower() in ('yes', 'true', 'y', '1') + persist_str = prompt(" Persist filesystem across sessions? (yes/no)", persist_label) + terminal["container_persistent"] = persist_str.lower() in ("yes", "true", "y", "1") # CPU - current_cpu = terminal.get('container_cpu', 1) - cpu_str = prompt(f" CPU cores", str(current_cpu)) + current_cpu = terminal.get("container_cpu", 1) + cpu_str = prompt(" CPU cores", str(current_cpu)) try: - terminal['container_cpu'] = float(cpu_str) + terminal["container_cpu"] = float(cpu_str) except ValueError: pass # Memory - current_mem = terminal.get('container_memory', 5120) - mem_str = prompt(f" Memory in MB (5120 = 5GB)", str(current_mem)) + current_mem = terminal.get("container_memory", 5120) + mem_str = prompt(" Memory in MB (5120 = 5GB)", str(current_mem)) try: - terminal['container_memory'] = int(mem_str) + terminal["container_memory"] = int(mem_str) except ValueError: pass # Disk - current_disk = terminal.get('container_disk', 51200) - disk_str = prompt(f" Disk in MB (51200 = 50GB)", str(current_disk)) + current_disk = terminal.get("container_disk", 51200) + disk_str = prompt(" Disk in MB (51200 = 50GB)", str(current_disk)) try: - terminal['container_disk'] = int(disk_str) + terminal["container_disk"] = int(disk_str) except ValueError: pass - # Tool categories and provider config are now in tools_config.py (shared # between `hermes tools` and `hermes setup tools`). @@ -473,14 +492,18 @@ def _prompt_container_resources(config: dict): # Section 1: Model & Provider Configuration # ============================================================================= + def setup_model_provider(config: dict): """Configure the inference provider and default model.""" from hermes_cli.auth import ( - get_active_provider, get_provider_auth_state, PROVIDER_REGISTRY, - format_auth_error, AuthError, fetch_nous_models, - resolve_nous_runtime_credentials, _update_config_for_provider, - _login_openai_codex, get_codex_auth_status, DEFAULT_CODEX_BASE_URL, + DEFAULT_CODEX_BASE_URL, + PROVIDER_REGISTRY, + _login_openai_codex, + _update_config_for_provider, detect_external_credentials, + fetch_nous_models, + get_active_provider, + resolve_nous_runtime_credentials, ) print_header("Inference Provider") @@ -497,14 +520,14 @@ def setup_model_provider(config: dict): print_info("Detected existing credentials:") for cred in detected_creds: if cred["provider"] == "openai-codex": - print_success(f" * {cred['label']} -- select \"OpenAI Codex\" to use it") + print_success(f' * {cred["label"]} -- select "OpenAI Codex" to use it') else: print_info(f" * {cred['label']}") print() # Detect if any provider is already configured has_any_provider = bool(active_oauth or existing_custom or existing_or) - + # Build "keep current" label if active_oauth and active_oauth in PROVIDER_REGISTRY: keep_label = f"Keep current ({PROVIDER_REGISTRY[active_oauth].name})" @@ -527,14 +550,14 @@ def setup_model_provider(config: dict): ] if keep_label: provider_choices.append(keep_label) - + # Default to "Keep current" if a provider exists, otherwise OpenRouter (most common) default_provider = len(provider_choices) - 1 if has_any_provider else 2 - + if not has_any_provider: print_warning("An inference provider is required for Hermes to work.") print() - + provider_idx = prompt_choice("Select your inference provider:", provider_choices, default_provider) # Track which provider was selected for model step @@ -550,12 +573,19 @@ def setup_model_provider(config: dict): print() try: - from hermes_cli.auth import _login_nous, ProviderConfig import argparse + + from hermes_cli.auth import _login_nous + mock_args = argparse.Namespace( - portal_url=None, inference_url=None, client_id=None, - scope=None, no_browser=False, timeout=15.0, - ca_bundle=None, insecure=False, + portal_url=None, + inference_url=None, + client_id=None, + scope=None, + no_browser=False, + timeout=15.0, + ca_bundle=None, + insecure=False, ) pconfig = PROVIDER_REGISTRY["nous"] _login_nous(mock_args, pconfig) @@ -563,7 +593,8 @@ def setup_model_provider(config: dict): # Fetch models for the selection step try: creds = resolve_nous_runtime_credentials( - min_key_ttl_seconds=5 * 60, timeout_seconds=15.0, + min_key_ttl_seconds=5 * 60, + timeout_seconds=15.0, ) nous_models = fetch_nous_models( inference_base_url=creds.get("base_url", ""), @@ -589,6 +620,7 @@ def setup_model_provider(config: dict): try: import argparse + mock_args = argparse.Namespace() _login_openai_codex(mock_args, PROVIDER_REGISTRY["openai-codex"]) # Clear custom endpoint vars that would override provider routing. @@ -636,10 +668,12 @@ def setup_model_provider(config: dict): # resolver doesn't keep returning the old provider (e.g. Codex). try: from hermes_cli.auth import deactivate_provider + deactivate_provider() except Exception: pass import yaml + config_path = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) / "config.yaml" try: disk_cfg = {} @@ -663,8 +697,8 @@ def setup_model_provider(config: dict): current_url = get_env_value("OPENAI_BASE_URL") or "" current_key = get_env_value("OPENAI_API_KEY") - _raw_model = config.get('model', '') - current_model = _raw_model.get('default', '') if isinstance(_raw_model, dict) else (_raw_model or '') + _raw_model = config.get("model", "") + current_model = _raw_model.get("default", "") if isinstance(_raw_model, dict) else (_raw_model or "") if current_url: print_info(f" Current URL: {current_url}") @@ -680,13 +714,14 @@ def setup_model_provider(config: dict): if api_key: save_env_value("OPENAI_API_KEY", api_key) if model_name: - config['model'] = model_name + config["model"] = model_name save_env_value("LLM_MODEL", model_name) # Save provider and base_url to config.yaml so the gateway and CLI # both resolve the correct provider without relying on env-var heuristics. if base_url: import yaml + config_path = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) / "config.yaml" try: disk_cfg = {} @@ -741,6 +776,7 @@ def setup_model_provider(config: dict): print() print_info("Detecting your z.ai endpoint...") from hermes_cli.auth import detect_zai_endpoint + detected = detect_zai_endpoint(api_key) if detected: zai_base_url = detected["base_url"] @@ -861,7 +897,15 @@ def setup_model_provider(config: dict): # ── OpenRouter API Key for tools (if not already set) ── # Tools (vision, web, MoA) use OpenRouter independently of the main provider. # Prompt for OpenRouter key if not set and a non-OpenRouter provider was chosen. - if selected_provider in ("nous", "openai-codex", "custom", "zai", "kimi-coding", "minimax", "minimax-cn") and not get_env_value("OPENROUTER_API_KEY"): + if selected_provider in ( + "nous", + "openai-codex", + "custom", + "zai", + "kimi-coding", + "minimax", + "minimax-cn", + ) and not get_env_value("OPENROUTER_API_KEY"): print() print_header("OpenRouter API Key (for tools)") print_info("Tools like vision analysis, web search, and MoA use OpenRouter") @@ -879,8 +923,12 @@ def setup_model_provider(config: dict): if selected_provider != "custom": # Custom already prompted for model name print_header("Default Model") - _raw_model = config.get('model', 'anthropic/claude-opus-4.6') - current_model = _raw_model.get('default', 'anthropic/claude-opus-4.6') if isinstance(_raw_model, dict) else (_raw_model or 'anthropic/claude-opus-4.6') + _raw_model = config.get("model", "anthropic/claude-opus-4.6") + current_model = ( + _raw_model.get("default", "anthropic/claude-opus-4.6") + if isinstance(_raw_model, dict) + else (_raw_model or "anthropic/claude-opus-4.6") + ) print_info(f"Current: {current_model}") if selected_provider == "nous" and nous_models: @@ -898,11 +946,11 @@ def setup_model_provider(config: dict): model_idx = prompt_choice("Select default model:", model_choices, len(model_choices) - 1) if model_idx < len(nous_models): - config['model'] = nous_models[model_idx] + config["model"] = nous_models[model_idx] elif model_idx == len(model_choices) - 2: # Custom model_name = prompt(" Model name") if model_name: - config['model'] = model_name + config["model"] = model_name # else: keep current elif selected_provider == "nous": @@ -912,10 +960,11 @@ def setup_model_provider(config: dict): print_info("Enter a Nous model name manually (e.g., claude-opus-4-6).") custom = prompt(f" Model name (Enter to keep '{current_model}')") if custom: - config['model'] = custom + config["model"] = custom save_env_value("LLM_MODEL", custom) elif selected_provider == "openai-codex": from hermes_cli.codex_models import get_codex_model_ids + codex_models = get_codex_model_ids() model_choices = codex_models + [f"Keep current ({current_model})"] default_codex = 0 @@ -926,12 +975,12 @@ def setup_model_provider(config: dict): model_idx = prompt_choice("Select default model:", model_choices, default_codex) if model_idx < len(codex_models): - config['model'] = codex_models[model_idx] + config["model"] = codex_models[model_idx] save_env_value("LLM_MODEL", codex_models[model_idx]) elif model_idx == len(codex_models): custom = prompt("Enter model name") if custom: - config['model'] = custom + config["model"] = custom save_env_value("LLM_MODEL", custom) _update_config_for_provider("openai-codex", DEFAULT_CODEX_BASE_URL) elif selected_provider == "zai": @@ -949,12 +998,12 @@ def setup_model_provider(config: dict): model_idx = prompt_choice("Select default model:", model_choices, keep_idx) if model_idx < len(zai_models): - config['model'] = zai_models[model_idx] + config["model"] = zai_models[model_idx] save_env_value("LLM_MODEL", zai_models[model_idx]) elif model_idx == len(zai_models): custom = prompt("Enter model name") if custom: - config['model'] = custom + config["model"] = custom save_env_value("LLM_MODEL", custom) # else: keep current elif selected_provider == "kimi-coding": @@ -967,12 +1016,12 @@ def setup_model_provider(config: dict): model_idx = prompt_choice("Select default model:", model_choices, keep_idx) if model_idx < len(kimi_models): - config['model'] = kimi_models[model_idx] + config["model"] = kimi_models[model_idx] save_env_value("LLM_MODEL", kimi_models[model_idx]) elif model_idx == len(kimi_models): custom = prompt("Enter model name") if custom: - config['model'] = custom + config["model"] = custom save_env_value("LLM_MODEL", custom) # else: keep current elif selected_provider in ("minimax", "minimax-cn"): @@ -985,17 +1034,17 @@ def setup_model_provider(config: dict): model_idx = prompt_choice("Select default model:", model_choices, keep_idx) if model_idx < len(minimax_models): - config['model'] = minimax_models[model_idx] + config["model"] = minimax_models[model_idx] save_env_value("LLM_MODEL", minimax_models[model_idx]) elif model_idx == len(minimax_models): custom = prompt("Enter model name") if custom: - config['model'] = custom + config["model"] = custom save_env_value("LLM_MODEL", custom) # else: keep current else: # Static list for OpenRouter / fallback (from canonical list) - from hermes_cli.models import model_ids, menu_labels + from hermes_cli.models import menu_labels, model_ids ids = model_ids() model_choices = menu_labels() + [ @@ -1007,18 +1056,18 @@ def setup_model_provider(config: dict): model_idx = prompt_choice("Select default model:", model_choices, keep_idx) if model_idx < len(ids): - config['model'] = ids[model_idx] + config["model"] = ids[model_idx] save_env_value("LLM_MODEL", ids[model_idx]) elif model_idx == len(ids): # Custom custom = prompt("Enter model name (e.g., anthropic/claude-opus-4.6)") if custom: - config['model'] = custom + config["model"] = custom save_env_value("LLM_MODEL", custom) # else: Keep current - _final_model = config.get('model', '') + _final_model = config.get("model", "") if _final_model: - _display = _final_model.get('default', _final_model) if isinstance(_final_model, dict) else _final_model + _display = _final_model.get("default", _final_model) if isinstance(_final_model, dict) else _final_model print_success(f"Model set to: {_display}") save_config(config) @@ -1028,6 +1077,7 @@ def setup_model_provider(config: dict): # Section 2: Terminal Backend Configuration # ============================================================================= + def setup_terminal_backend(config: dict): """Configure the terminal execution backend.""" import platform as _platform @@ -1038,7 +1088,7 @@ def setup_terminal_backend(config: dict): print_info("This affects tool execution, file access, and isolation.") print() - current_backend = config.get('terminal', {}).get('backend', 'local') + current_backend = config.get("terminal", {}).get("backend", "local") is_linux = _platform.system() == "Linux" # Build backend choices with descriptions @@ -1074,21 +1124,21 @@ def setup_terminal_backend(config: dict): print_info(f"Keeping current backend: {current_backend}") return - config.setdefault('terminal', {})['backend'] = selected_backend + config.setdefault("terminal", {})["backend"] = selected_backend if selected_backend == "local": print_success("Terminal backend: Local") print_info("Commands run directly on this machine.") - + # CWD for messaging print() print_info("Working directory for messaging sessions:") print_info(" When using Hermes via Telegram/Discord, this is where") print_info(" the agent starts. CLI mode always starts in the current directory.") - current_cwd = config.get('terminal', {}).get('cwd', '') + current_cwd = config.get("terminal", {}).get("cwd", "") cwd = prompt(" Messaging working directory", current_cwd or str(Path.home())) if cwd: - config['terminal']['cwd'] = cwd + config["terminal"]["cwd"] = cwd # Sudo support print() @@ -1114,9 +1164,9 @@ def setup_terminal_backend(config: dict): print_info(f"Docker found: {docker_bin}") # Docker image - current_image = config.get('terminal', {}).get('docker_image', 'python:3.11-slim') + current_image = config.get("terminal", {}).get("docker_image", "python:3.11-slim") image = prompt(" Docker image", current_image) - config['terminal']['docker_image'] = image + config["terminal"]["docker_image"] = image save_env_value("TERMINAL_DOCKER_IMAGE", image) _prompt_container_resources(config) @@ -1132,9 +1182,9 @@ def setup_terminal_backend(config: dict): else: print_info(f"Found: {sing_bin}") - current_image = config.get('terminal', {}).get('singularity_image', 'docker://python:3.11-slim') + current_image = config.get("terminal", {}).get("singularity_image", "docker://python:3.11-slim") image = prompt(" Container image", current_image) - config['terminal']['singularity_image'] = image + config["terminal"]["singularity_image"] = image save_env_value("TERMINAL_SINGULARITY_IMAGE", image) _prompt_container_resources(config) @@ -1150,16 +1200,17 @@ def setup_terminal_backend(config: dict): except ImportError: print_info("Installing swe-rex[modal]...") import subprocess + uv_bin = shutil.which("uv") if uv_bin: result = subprocess.run( [uv_bin, "pip", "install", "--python", sys.executable, "swe-rex[modal]"], - capture_output=True, text=True + capture_output=True, + text=True, ) else: result = subprocess.run( - [sys.executable, "-m", "pip", "install", "swe-rex[modal]"], - capture_output=True, text=True + [sys.executable, "-m", "pip", "install", "swe-rex[modal]"], capture_output=True, text=True ) if result.returncode == 0: print_success("swe-rex[modal] installed") @@ -1202,16 +1253,15 @@ def setup_terminal_backend(config: dict): except ImportError: print_info("Installing daytona SDK...") import subprocess + uv_bin = shutil.which("uv") if uv_bin: result = subprocess.run( - [uv_bin, "pip", "install", "--python", sys.executable, "daytona"], - capture_output=True, text=True + [uv_bin, "pip", "install", "--python", sys.executable, "daytona"], capture_output=True, text=True ) else: result = subprocess.run( - [sys.executable, "-m", "pip", "install", "daytona"], - capture_output=True, text=True + [sys.executable, "-m", "pip", "install", "daytona"], capture_output=True, text=True ) if result.returncode == 0: print_success("daytona SDK installed") @@ -1237,9 +1287,9 @@ def setup_terminal_backend(config: dict): print_success(" Configured") # Daytona image - current_image = config.get('terminal', {}).get('daytona_image', 'nikolaik/python-nodejs:python3.11-nodejs20') + current_image = config.get("terminal", {}).get("daytona_image", "nikolaik/python-nodejs:python3.11-nodejs20") image = prompt(" Sandbox image", current_image) - config['terminal']['daytona_image'] = image + config["terminal"]["daytona_image"] = image save_env_value("TERMINAL_DAYTONA_IMAGE", image) _prompt_container_resources(config) @@ -1277,6 +1327,7 @@ def setup_terminal_backend(config: dict): if host and prompt_yes_no(" Test SSH connection?", True): print_info(" Testing connection...") import subprocess + ssh_cmd = ["ssh", "-o", "BatchMode=yes", "-o", "ConnectTimeout=5"] if ssh_key: ssh_cmd.extend(["-i", ssh_key]) @@ -1303,27 +1354,28 @@ def setup_terminal_backend(config: dict): # Section 3: Agent Settings # ============================================================================= + def setup_agent_settings(config: dict): """Configure agent behavior: iterations, progress display, compression, session reset.""" # ── Max Iterations ── print_header("Agent Settings") - current_max = get_env_value('HERMES_MAX_ITERATIONS') or '90' + current_max = get_env_value("HERMES_MAX_ITERATIONS") or "90" print_info("Maximum tool-calling iterations per conversation.") print_info("Higher = more complex tasks, but costs more tokens.") print_info("Recommended: 30-60 for most tasks, 100+ for open exploration.") - + max_iter_str = prompt("Max iterations", current_max) try: max_iter = int(max_iter_str) if max_iter > 0: save_env_value("HERMES_MAX_ITERATIONS", str(max_iter)) - config['max_turns'] = max_iter + config["max_turns"] = max_iter print_success(f"Max iterations set to {max_iter}") except ValueError: print_warning("Invalid number, keeping current value") - + # ── Tool Progress Display ── print_info("") print_info("Tool Progress Display") @@ -1332,7 +1384,7 @@ def setup_agent_settings(config: dict): print_info(" new — Show tool name only when it changes (less noise)") print_info(" all — Show every tool call with a short preview") print_info(" verbose — Full args, results, and debug logs") - + current_mode = config.get("display", {}).get("tool_progress", "all") mode = prompt("Tool progress mode", current_mode) if mode.lower() in ("off", "new", "all", "verbose"): @@ -1348,18 +1400,18 @@ def setup_agent_settings(config: dict): print_header("Context Compression") print_info("Automatically summarizes old messages when context gets too long.") print_info("Higher threshold = compress later (use more context). Lower = compress sooner.") - - config.setdefault('compression', {})['enabled'] = True - - current_threshold = config.get('compression', {}).get('threshold', 0.85) + + config.setdefault("compression", {})["enabled"] = True + + current_threshold = config.get("compression", {}).get("threshold", 0.85) threshold_str = prompt("Compression threshold (0.5-0.95)", str(current_threshold)) try: threshold = float(threshold_str) if 0.5 <= threshold <= 0.95: - config['compression']['threshold'] = threshold + config["compression"]["threshold"] = threshold except ValueError: pass - + print_success(f"Context compression threshold set to {config['compression'].get('threshold', 0.85)}") # ── Session Reset Policy ── @@ -1373,7 +1425,7 @@ def setup_agent_settings(config: dict): print_info("") print_info("You can also manually reset anytime by typing /reset in chat.") print_info("") - + reset_choices = [ "Inactivity + daily reset (recommended - reset whichever comes first)", "Inactivity only (reset after N minutes of no messages)", @@ -1381,61 +1433,63 @@ def setup_agent_settings(config: dict): "Never auto-reset (context lives until /reset or context compression)", "Keep current settings", ] - - current_policy = config.get('session_reset', {}) - current_mode = current_policy.get('mode', 'both') - current_idle = current_policy.get('idle_minutes', 1440) - current_hour = current_policy.get('at_hour', 4) - + + current_policy = config.get("session_reset", {}) + current_mode = current_policy.get("mode", "both") + current_idle = current_policy.get("idle_minutes", 1440) + current_hour = current_policy.get("at_hour", 4) + default_reset = {"both": 0, "idle": 1, "daily": 2, "none": 3}.get(current_mode, 0) - + reset_idx = prompt_choice("Session reset mode:", reset_choices, default_reset) - - config.setdefault('session_reset', {}) - + + config.setdefault("session_reset", {}) + if reset_idx == 0: # Both - config['session_reset']['mode'] = 'both' + config["session_reset"]["mode"] = "both" idle_str = prompt(" Inactivity timeout (minutes)", str(current_idle)) try: idle_val = int(idle_str) if idle_val > 0: - config['session_reset']['idle_minutes'] = idle_val + config["session_reset"]["idle_minutes"] = idle_val except ValueError: pass hour_str = prompt(" Daily reset hour (0-23, local time)", str(current_hour)) try: hour_val = int(hour_str) if 0 <= hour_val <= 23: - config['session_reset']['at_hour'] = hour_val + config["session_reset"]["at_hour"] = hour_val except ValueError: pass - print_success(f"Sessions reset after {config['session_reset'].get('idle_minutes', 1440)} min idle or daily at {config['session_reset'].get('at_hour', 4)}:00") + print_success( + f"Sessions reset after {config['session_reset'].get('idle_minutes', 1440)} min idle or daily at {config['session_reset'].get('at_hour', 4)}:00" + ) elif reset_idx == 1: # Idle only - config['session_reset']['mode'] = 'idle' + config["session_reset"]["mode"] = "idle" idle_str = prompt(" Inactivity timeout (minutes)", str(current_idle)) try: idle_val = int(idle_str) if idle_val > 0: - config['session_reset']['idle_minutes'] = idle_val + config["session_reset"]["idle_minutes"] = idle_val except ValueError: pass print_success(f"Sessions reset after {config['session_reset'].get('idle_minutes', 1440)} min of inactivity") elif reset_idx == 2: # Daily only - config['session_reset']['mode'] = 'daily' + config["session_reset"]["mode"] = "daily" hour_str = prompt(" Daily reset hour (0-23, local time)", str(current_hour)) try: hour_val = int(hour_str) if 0 <= hour_val <= 23: - config['session_reset']['at_hour'] = hour_val + config["session_reset"]["at_hour"] = hour_val except ValueError: pass print_success(f"Sessions reset daily at {config['session_reset'].get('at_hour', 4)}:00") elif reset_idx == 3: # None - config['session_reset']['mode'] = 'none' + config["session_reset"]["mode"] = "none" print_info("Sessions will never auto-reset. Context is managed only by compression.") print_warning("Long conversations will grow in cost. Use /reset manually when needed.") # else: keep current (idx == 4) - + save_config(config) @@ -1443,6 +1497,7 @@ def setup_agent_settings(config: dict): # Section 4: Messaging Platforms (Gateway) # ============================================================================= + def setup_gateway(config: dict): """Configure messaging platform integrations.""" print_header("Messaging Platforms") @@ -1450,19 +1505,19 @@ def setup_gateway(config: dict): print() # ── Telegram ── - existing_telegram = get_env_value('TELEGRAM_BOT_TOKEN') + existing_telegram = get_env_value("TELEGRAM_BOT_TOKEN") if existing_telegram: print_info("Telegram: already configured") if prompt_yes_no("Reconfigure Telegram?", False): existing_telegram = None - + if not existing_telegram and prompt_yes_no("Set up Telegram bot?", False): print_info("Create a bot via @BotFather on Telegram") token = prompt("Telegram bot token", password=True) if token: save_env_value("TELEGRAM_BOT_TOKEN", token) print_success("Telegram token saved") - + # Allowed users (security) print() print_info("🔒 Security: Restrict who can use your bot") @@ -1476,13 +1531,13 @@ def setup_gateway(config: dict): print_success("Telegram allowlist configured - only listed users can use the bot") else: print_info("⚠️ No allowlist set - anyone who finds your bot can use it!") - + # Home channel setup with better guidance print() print_info("📬 Home Channel: where Hermes delivers cron job results,") print_info(" cross-platform messages, and notifications.") print_info(" For Telegram DMs, this is your user ID (same as above).") - + first_user_id = allowed_users.split(",")[0].strip() if allowed_users else "" if first_user_id: if prompt_yes_no(f"Use your user ID ({first_user_id}) as the home channel?", True): @@ -1497,10 +1552,10 @@ def setup_gateway(config: dict): home_channel = prompt("Home channel ID (leave empty to set later)") if home_channel: save_env_value("TELEGRAM_HOME_CHANNEL", home_channel) - + # Check/update existing Telegram allowlist elif existing_telegram: - existing_allowlist = get_env_value('TELEGRAM_ALLOWED_USERS') + existing_allowlist = get_env_value("TELEGRAM_ALLOWED_USERS") if not existing_allowlist: print_info("⚠️ Telegram has no user allowlist - anyone can use your bot!") if prompt_yes_no("Add allowed users now?", True): @@ -1509,21 +1564,21 @@ def setup_gateway(config: dict): if allowed_users: save_env_value("TELEGRAM_ALLOWED_USERS", allowed_users.replace(" ", "")) print_success("Telegram allowlist configured") - + # ── Discord ── - existing_discord = get_env_value('DISCORD_BOT_TOKEN') + existing_discord = get_env_value("DISCORD_BOT_TOKEN") if existing_discord: print_info("Discord: already configured") if prompt_yes_no("Reconfigure Discord?", False): existing_discord = None - + if not existing_discord and prompt_yes_no("Set up Discord bot?", False): print_info("Create a bot at https://discord.com/developers/applications") token = prompt("Discord bot token", password=True) if token: save_env_value("DISCORD_BOT_TOKEN", token) print_success("Discord token saved") - + # Allowed users (security) print() print_info("🔒 Security: Restrict who can use your bot") @@ -1539,7 +1594,7 @@ def setup_gateway(config: dict): print_success("Discord allowlist configured") else: print_info("⚠️ No allowlist set - anyone in servers with your bot can use it!") - + # Home channel setup with better guidance print() print_info("📬 Home Channel: where Hermes delivers cron job results,") @@ -1550,10 +1605,10 @@ def setup_gateway(config: dict): home_channel = prompt("Home channel ID (leave empty to set later with /set-home)") if home_channel: save_env_value("DISCORD_HOME_CHANNEL", home_channel) - + # Check/update existing Discord allowlist elif existing_discord: - existing_allowlist = get_env_value('DISCORD_ALLOWED_USERS') + existing_allowlist = get_env_value("DISCORD_ALLOWED_USERS") if not existing_allowlist: print_info("⚠️ Discord has no user allowlist - anyone can use your bot!") if prompt_yes_no("Add allowed users now?", True): @@ -1562,14 +1617,14 @@ def setup_gateway(config: dict): if allowed_users: save_env_value("DISCORD_ALLOWED_USERS", allowed_users.replace(" ", "")) print_success("Discord allowlist configured") - + # ── Slack ── - existing_slack = get_env_value('SLACK_BOT_TOKEN') + existing_slack = get_env_value("SLACK_BOT_TOKEN") if existing_slack: print_info("Slack: already configured") if prompt_yes_no("Reconfigure Slack?", False): existing_slack = None - + if not existing_slack and prompt_yes_no("Set up Slack bot?", False): print_info("Steps to create a Slack app:") print_info(" 1. Go to https://api.slack.com/apps → Create New App (from scratch)") @@ -1596,7 +1651,7 @@ def setup_gateway(config: dict): if app_token: save_env_value("SLACK_APP_TOKEN", app_token) print_success("Slack tokens saved") - + print() print_info("🔒 Security: Restrict who can use your bot") print_info(" To find a Member ID: click a user's name → View full profile → ⋮ → Copy member ID") @@ -1607,9 +1662,9 @@ def setup_gateway(config: dict): print_success("Slack allowlist configured") else: print_info("⚠️ No allowlist set - anyone in your workspace can use the bot!") - + # ── WhatsApp ── - existing_whatsapp = get_env_value('WHATSAPP_ENABLED') + existing_whatsapp = get_env_value("WHATSAPP_ENABLED") if not existing_whatsapp and prompt_yes_no("Set up WhatsApp?", False): print_info("WhatsApp connects via a built-in bridge (Baileys).") print_info("Requires Node.js. Run 'hermes whatsapp' for guided setup.") @@ -1619,13 +1674,13 @@ def setup_gateway(config: dict): print_success("WhatsApp enabled") print_info("Run 'hermes whatsapp' to choose your mode (separate bot number") print_info("or personal self-chat) and pair via QR code.") - + # ── Gateway Service Setup ── any_messaging = ( - get_env_value('TELEGRAM_BOT_TOKEN') - or get_env_value('DISCORD_BOT_TOKEN') - or get_env_value('SLACK_BOT_TOKEN') - or get_env_value('WHATSAPP_ENABLED') + get_env_value("TELEGRAM_BOT_TOKEN") + or get_env_value("DISCORD_BOT_TOKEN") + or get_env_value("SLACK_BOT_TOKEN") + or get_env_value("WHATSAPP_ENABLED") ) if any_messaging: print() @@ -1634,11 +1689,11 @@ def setup_gateway(config: dict): # Check if any home channels are missing missing_home = [] - if get_env_value('TELEGRAM_BOT_TOKEN') and not get_env_value('TELEGRAM_HOME_CHANNEL'): + if get_env_value("TELEGRAM_BOT_TOKEN") and not get_env_value("TELEGRAM_HOME_CHANNEL"): missing_home.append("Telegram") - if get_env_value('DISCORD_BOT_TOKEN') and not get_env_value('DISCORD_HOME_CHANNEL'): + if get_env_value("DISCORD_BOT_TOKEN") and not get_env_value("DISCORD_HOME_CHANNEL"): missing_home.append("Discord") - if get_env_value('SLACK_BOT_TOKEN') and not get_env_value('SLACK_HOME_CHANNEL'): + if get_env_value("SLACK_BOT_TOKEN") and not get_env_value("SLACK_HOME_CHANNEL"): missing_home.append("Slack") if missing_home: @@ -1652,13 +1707,19 @@ def setup_gateway(config: dict): # Offer to install the gateway as a system service import platform as _platform + _is_linux = _platform.system() == "Linux" _is_macos = _platform.system() == "Darwin" from hermes_cli.gateway import ( - _is_service_installed, _is_service_running, - systemd_install, systemd_start, systemd_restart, - launchd_install, launchd_start, launchd_restart, + _is_service_installed, + _is_service_running, + launchd_install, + launchd_restart, + launchd_start, + systemd_install, + systemd_restart, + systemd_start, ) service_installed = _is_service_installed() @@ -1685,7 +1746,9 @@ def setup_gateway(config: dict): print_error(f" Start failed: {e}") elif _is_linux or _is_macos: svc_name = "systemd" if _is_linux else "launchd" - if prompt_yes_no(f" Install the gateway as a {svc_name} service? (runs in background, starts on boot)", True): + if prompt_yes_no( + f" Install the gateway as a {svc_name} service? (runs in background, starts on boot)", True + ): try: if _is_linux: systemd_install(force=False) @@ -1717,17 +1780,19 @@ def setup_gateway(config: dict): # Section 5: Tool Configuration (delegates to unified tools_config.py) # ============================================================================= + def setup_tools(config: dict, first_install: bool = False): """Configure tools — delegates to the unified tools_command() in tools_config.py. - + Both `hermes setup tools` and `hermes tools` use the same flow: platform selection → toolset toggles → provider/API key configuration. - + Args: first_install: When True, uses the simplified first-install flow (no platform menu, prompts for all unconfigured API keys). """ from hermes_cli.tools_config import tools_command + tools_command(first_install=first_install, config=config) @@ -1746,7 +1811,7 @@ SETUP_SECTIONS = [ def run_setup_wizard(args): """Run the interactive setup wizard. - + Supports full, quick, and section-specific setup: hermes setup — full or quick (auto-detected) hermes setup model — just model/provider @@ -1756,12 +1821,12 @@ def run_setup_wizard(args): hermes setup agent — just agent settings """ ensure_hermes_home() - + config = load_config() hermes_home = get_hermes_home() - + # Check if a specific section was requested - section = getattr(args, 'section', None) + section = getattr(args, "section", None) if section: for key, label, func in SETUP_SECTIONS: if key == section: @@ -1774,20 +1839,21 @@ def run_setup_wizard(args): print() print_success(f"{label} configuration complete!") return - + print_error(f"Unknown setup section: {section}") print_info(f"Available sections: {', '.join(k for k, _, _ in SETUP_SECTIONS)}") return - + # Check if this is an existing installation with a provider configured from hermes_cli.auth import get_active_provider + active_provider = get_active_provider() is_existing = ( bool(get_env_value("OPENROUTER_API_KEY")) or bool(get_env_value("OPENAI_BASE_URL")) or active_provider is not None ) - + print() print(color("┌─────────────────────────────────────────────────────────┐", Colors.MAGENTA)) print(color("│ ⚕ Hermes Agent Setup Wizard │", Colors.MAGENTA)) @@ -1795,7 +1861,7 @@ def run_setup_wizard(args): print(color("│ Let's configure your Hermes Agent installation. │", Colors.MAGENTA)) print(color("│ Press Ctrl+C at any time to exit. │", Colors.MAGENTA)) print(color("└─────────────────────────────────────────────────────────┘", Colors.MAGENTA)) - + if is_existing: # ── Returning User Menu ── print() @@ -1891,8 +1957,9 @@ def run_setup_wizard(args): def _run_quick_setup(config: dict, hermes_home): """Quick setup — only configure items that are missing.""" from hermes_cli.config import ( - get_missing_env_vars, get_missing_config_fields, - check_config_version, migrate_config, + check_config_version, + get_missing_config_fields, + get_missing_env_vars, ) print() @@ -1927,12 +1994,12 @@ def _run_quick_setup(config: dict, hermes_home): print_info(f" {var.get('description', '')}") if var.get("url"): print_info(f" Get key at: {var['url']}") - + if var.get("password"): value = prompt(f" {var.get('prompt', var['name'])}", password=True) else: value = prompt(f" {var.get('prompt', var['name'])}") - + if value: save_env_value(var["name"], value) print_success(f" Saved {var['name']}") @@ -1988,8 +2055,7 @@ def _run_quick_setup(config: dict, hermes_home): platforms.setdefault(plat, []).append(var) platform_labels = [ - {"Telegram": "📱 Telegram", "Discord": "💬 Discord", "Slack": "💼 Slack"}.get(p, p) - for p in platform_order + {"Telegram": "📱 Telegram", "Discord": "💬 Discord", "Slack": "💼 Slack"}.get(p, p) for p in platform_order ] selected_indices = prompt_checklist( @@ -2014,9 +2080,9 @@ def _run_quick_setup(config: dict, hermes_home): value = prompt(f" {var.get('prompt', var['name'])}") if value: save_env_value(var["name"], value) - print_success(f" ✓ Saved") + print_success(" ✓ Saved") else: - print_warning(f" Skipped") + print_warning(" Skipped") print() # Handle missing config fields @@ -2025,7 +2091,7 @@ def _run_quick_setup(config: dict, hermes_home): print_info(f"Adding {len(missing_config)} new config option(s) with defaults...") for field in missing_config: print_success(f" Added {field['key']} = {field['default']}") - + # Update config version config["_config_version"] = latest_ver save_config(config) diff --git a/hermes_cli/skills_hub.py b/hermes_cli/skills_hub.py index 8b72fe4f46..5cf321b6ff 100644 --- a/hermes_cli/skills_hub.py +++ b/hermes_cli/skills_hub.py @@ -13,7 +13,6 @@ handler are thin wrappers that parse args and delegate. import json import shutil from pathlib import Path -from typing import Optional from rich.console import Console from rich.panel import Panel @@ -29,6 +28,7 @@ _console = Console() # Shared do_* functions # --------------------------------------------------------------------------- + def _resolve_short_name(name: str, sources, console: Console) -> str: """ Resolve a short skill name (e.g. 'pptx') to a full identifier by searching @@ -57,7 +57,9 @@ def _resolve_short_name(name: str, sources, console: Console) -> str: table.add_column("Trust", style="dim") table.add_column("Identifier", style="bold cyan") for r in exact: - trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(r.trust_level, "dim") + trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get( + r.trust_level, "dim" + ) trust_label = "official" if r.source == "official" else r.trust_level table.add_row(r.source, f"[{trust_style}]{trust_label}[/]", r.identifier) c.print(table) @@ -76,8 +78,7 @@ def _resolve_short_name(name: str, sources, console: Console) -> str: return "" -def do_search(query: str, source: str = "all", limit: int = 10, - console: Optional[Console] = None) -> None: +def do_search(query: str, source: str = "all", limit: int = 10, console: Console | None = None) -> None: """Search registries and display results as a Rich table.""" from tools.skills_hub import GitHubAuth, create_source_router, unified_search @@ -111,18 +112,19 @@ def do_search(query: str, source: str = "all", limit: int = 10, ) c.print(table) - c.print("[dim]Use: hermes skills inspect to preview, " - "hermes skills install to install[/]\n") + c.print( + "[dim]Use: hermes skills inspect to preview, hermes skills install to install[/]\n" + ) -def do_browse(page: int = 1, page_size: int = 20, source: str = "all", - console: Optional[Console] = None) -> None: +def do_browse(page: int = 1, page_size: int = 20, source: str = "all", console: Console | None = None) -> None: """Browse all available skills across registries, paginated. Official skills are always shown first, regardless of source filter. """ from tools.skills_hub import ( - GitHubAuth, create_source_router, OptionalSkillSource, SkillMeta, + GitHubAuth, + create_source_router, ) # Clamp page_size to safe range @@ -136,8 +138,7 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all", # Collect results from all (or filtered) sources # Use empty query to get everything; per-source limits prevent overload _TRUST_RANK = {"builtin": 3, "trusted": 2, "community": 1} - _PER_SOURCE_LIMIT = {"official": 100, "github": 100, "clawhub": 50, - "claude-marketplace": 50, "lobehub": 50} + _PER_SOURCE_LIMIT = {"official": 100, "github": 100, "clawhub": 50, "claude-marketplace": 50, "lobehub": 50} all_results: list = [] source_counts: dict = {} @@ -168,11 +169,13 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all", deduped = list(seen.values()) # Sort: official first, then by trust level (desc), then alphabetically - deduped.sort(key=lambda r: ( - -_TRUST_RANK.get(r.trust_level, 0), - r.source != "official", - r.name.lower(), - )) + deduped.sort( + key=lambda r: ( + -_TRUST_RANK.get(r.trust_level, 0), + r.source != "official", + r.name.lower(), + ) + ) # Paginate total = len(deduped) @@ -187,8 +190,7 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all", # Build header source_label = f"— {source}" if source != "all" else "— all sources" - c.print(f"\n[bold]Skills Hub — Browse {source_label}[/]" - f" [dim]({total} skills, page {page}/{total_pages})[/]") + c.print(f"\n[bold]Skills Hub — Browse {source_label}[/] [dim]({total} skills, page {page}/{total_pages})[/]") if official_count > 0 and page == 1: c.print(f"[bright_cyan]★ {official_count} official optional skill(s) from Nous Research[/]") c.print() @@ -202,8 +204,7 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all", table.add_column("Trust", width=10) for i, r in enumerate(page_items, start=start + 1): - trust_style = {"builtin": "bright_cyan", "trusted": "green", - "community": "yellow"}.get(r.trust_level, "dim") + trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(r.trust_level, "dim") trust_label = "★ official" if r.source == "official" else r.trust_level desc = r.description[:50] @@ -235,18 +236,22 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all", parts = [f"{sid}: {ct}" for sid, ct in sorted(source_counts.items())] c.print(f" [dim]Sources: {', '.join(parts)}[/]") - c.print("[dim]Use: hermes skills inspect to preview, " - "hermes skills install to install[/]\n") - - -def do_install(identifier: str, category: str = "", force: bool = False, - console: Optional[Console] = None) -> None: - """Fetch, quarantine, scan, confirm, and install a skill.""" - from tools.skills_hub import ( - GitHubAuth, create_source_router, ensure_hub_dirs, - quarantine_bundle, install_from_quarantine, HubLockFile, + c.print( + "[dim]Use: hermes skills inspect to preview, hermes skills install to install[/]\n" + ) + + +def do_install(identifier: str, category: str = "", force: bool = False, console: Console | None = None) -> None: + """Fetch, quarantine, scan, confirm, and install a skill.""" + from tools.skills_guard import format_scan_report, scan_skill, should_allow_install + from tools.skills_hub import ( + GitHubAuth, + HubLockFile, + create_source_router, + ensure_hub_dirs, + install_from_quarantine, + quarantine_bundle, ) - from tools.skills_guard import scan_skill, should_allow_install, format_scan_report c = console or _console ensure_hub_dirs() @@ -304,33 +309,43 @@ def do_install(identifier: str, category: str = "", force: bool = False, # Clean up quarantine shutil.rmtree(q_path, ignore_errors=True) from tools.skills_hub import append_audit_log - append_audit_log("BLOCKED", bundle.name, bundle.source, - bundle.trust_level, result.verdict, - f"{len(result.findings)}_findings") + + append_audit_log( + "BLOCKED", + bundle.name, + bundle.source, + bundle.trust_level, + result.verdict, + f"{len(result.findings)}_findings", + ) return # Confirm with user — show appropriate warning based on source if not force: c.print() if bundle.source == "official": - c.print(Panel( - "[bold bright_cyan]This is an official optional skill maintained by Nous Research.[/]\n\n" - "It ships with hermes-agent but is not activated by default.\n" - "Installing will copy it to your skills directory where the agent can use it.\n\n" - f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]", - title="Official Skill", - border_style="bright_cyan", - )) + c.print( + Panel( + "[bold bright_cyan]This is an official optional skill maintained by Nous Research.[/]\n\n" + "It ships with hermes-agent but is not activated by default.\n" + "Installing will copy it to your skills directory where the agent can use it.\n\n" + f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]", + title="Official Skill", + border_style="bright_cyan", + ) + ) else: - c.print(Panel( - "[bold yellow]You are installing a third-party skill at your own risk.[/]\n\n" - "External skills can contain instructions that influence agent behavior,\n" - "shell commands, and scripts. Even after automated scanning, you should\n" - "review the installed files before use.\n\n" - f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]", - title="Disclaimer", - border_style="yellow", - )) + c.print( + Panel( + "[bold yellow]You are installing a third-party skill at your own risk.[/]\n\n" + "External skills can contain instructions that influence agent behavior,\n" + "shell commands, and scripts. Even after automated scanning, you should\n" + "review the installed files before use.\n\n" + f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]", + title="Disclaimer", + border_style="yellow", + ) + ) c.print(f"[bold]Install '{bundle.name}'?[/]") try: answer = input("Confirm [y/N]: ").strip().lower() @@ -344,11 +359,12 @@ def do_install(identifier: str, category: str = "", force: bool = False, # Install install_dir = install_from_quarantine(q_path, bundle.name, category, bundle, result) from tools.skills_hub import SKILLS_DIR + c.print(f"[bold green]Installed:[/] {install_dir.relative_to(SKILLS_DIR)}") c.print(f"[dim]Files: {', '.join(bundle.files.keys())}[/]\n") -def do_inspect(identifier: str, console: Optional[Console] = None) -> None: +def do_inspect(identifier: str, console: Console | None = None) -> None: """Preview a skill's SKILL.md content without installing.""" from tools.skills_hub import GitHubAuth, create_source_router @@ -406,7 +422,7 @@ def do_inspect(identifier: str, console: Optional[Console] = None) -> None: c.print() -def do_list(source_filter: str = "all", console: Optional[Console] = None) -> None: +def do_list(source_filter: str = "all", console: Console | None = None) -> None: """List installed skills, distinguishing builtins from hub-installed.""" from tools.skills_hub import HubLockFile, ensure_hub_dirs from tools.skills_tool import _find_all_skills @@ -446,14 +462,13 @@ def do_list(source_filter: str = "all", console: Optional[Console] = None) -> No table.add_row(name, category, source_display, f"[{trust_style}]{trust_label}[/]") c.print(table) - c.print(f"[dim]{len(hub_installed)} hub-installed, " - f"{len(all_skills) - len(hub_installed)} builtin[/]\n") + c.print(f"[dim]{len(hub_installed)} hub-installed, {len(all_skills) - len(hub_installed)} builtin[/]\n") -def do_audit(name: Optional[str] = None, console: Optional[Console] = None) -> None: +def do_audit(name: str | None = None, console: Console | None = None) -> None: """Re-run security scan on installed hub skills.""" - from tools.skills_hub import HubLockFile, SKILLS_DIR - from tools.skills_guard import scan_skill, format_scan_report + from tools.skills_guard import format_scan_report, scan_skill + from tools.skills_hub import SKILLS_DIR, HubLockFile c = console or _console lock = HubLockFile() @@ -483,7 +498,7 @@ def do_audit(name: Optional[str] = None, console: Optional[Console] = None) -> N c.print() -def do_uninstall(name: str, console: Optional[Console] = None) -> None: +def do_uninstall(name: str, console: Console | None = None) -> None: """Remove a hub-installed skill with confirmation.""" from tools.skills_hub import uninstall_skill @@ -505,7 +520,7 @@ def do_uninstall(name: str, console: Optional[Console] = None) -> None: c.print(f"[bold red]Error:[/] {msg}\n") -def do_tap(action: str, repo: str = "", console: Optional[Console] = None) -> None: +def do_tap(action: str, repo: str = "", console: Console | None = None) -> None: """Manage taps (custom GitHub repo sources).""" from tools.skills_hub import TapsManager @@ -547,11 +562,10 @@ def do_tap(action: str, repo: str = "", console: Optional[Console] = None) -> No c.print(f"[bold red]Unknown tap action:[/] {action}. Use: list, add, remove\n") -def do_publish(skill_path: str, target: str = "github", repo: str = "", - console: Optional[Console] = None) -> None: +def do_publish(skill_path: str, target: str = "github", repo: str = "", console: Console | None = None) -> None: """Publish a local skill to a registry (GitHub PR or ClawHub submission).""" - from tools.skills_hub import GitHubAuth, SKILLS_DIR - from tools.skills_guard import scan_skill, format_scan_report + from tools.skills_guard import format_scan_report, scan_skill + from tools.skills_hub import SKILLS_DIR, GitHubAuth c = console or _console path = Path(skill_path) @@ -565,14 +579,16 @@ def do_publish(skill_path: str, target: str = "github", repo: str = "", # Validate the skill import yaml + skill_md = (path / "SKILL.md").read_text(encoding="utf-8") fm = {} if skill_md.startswith("---"): import re - match = re.search(r'\n---\s*\n', skill_md[3:]) + + match = re.search(r"\n---\s*\n", skill_md[3:]) if match: try: - fm = yaml.safe_load(skill_md[3:match.start() + 3]) or {} + fm = yaml.safe_load(skill_md[3 : match.start() + 3]) or {} except yaml.YAMLError: pass @@ -592,14 +608,18 @@ def do_publish(skill_path: str, target: str = "github", repo: str = "", if target == "github": if not repo: - c.print("[bold red]Error:[/] --repo required for GitHub publish.\n" - "Usage: hermes skills publish --to github --repo owner/repo\n") + c.print( + "[bold red]Error:[/] --repo required for GitHub publish.\n" + "Usage: hermes skills publish --to github --repo owner/repo\n" + ) return auth = GitHubAuth() if not auth.is_authenticated(): - c.print("[bold red]Error:[/] GitHub authentication required.\n" - "Set GITHUB_TOKEN in ~/.hermes/.env or run 'gh auth login'.\n") + c.print( + "[bold red]Error:[/] GitHub authentication required.\n" + "Set GITHUB_TOKEN in ~/.hermes/.env or run 'gh auth login'.\n" + ) return c.print(f"[bold]Publishing '{name}' to {repo}...[/]") @@ -610,14 +630,12 @@ def do_publish(skill_path: str, target: str = "github", repo: str = "", c.print(f"[bold red]Error:[/] {msg}\n") elif target == "clawhub": - c.print("[yellow]ClawHub publishing is not yet supported. " - "Submit manually at https://clawhub.ai/submit[/]\n") + c.print("[yellow]ClawHub publishing is not yet supported. Submit manually at https://clawhub.ai/submit[/]\n") else: c.print(f"[bold red]Unknown target:[/] {target}. Use 'github' or 'clawhub'.\n") -def _github_publish(skill_path: Path, skill_name: str, target_repo: str, - auth) -> tuple: +def _github_publish(skill_path: Path, skill_name: str, target_repo: str, auth) -> tuple: """Create a PR to a GitHub repo with the skill. Returns (success, message).""" import httpx @@ -627,7 +645,8 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str, try: resp = httpx.post( f"https://api.github.com/repos/{target_repo}/forks", - headers=headers, timeout=30, + headers=headers, + timeout=30, ) if resp.status_code in (200, 202): fork = resp.json() @@ -643,7 +662,8 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str, try: resp = httpx.get( f"https://api.github.com/repos/{target_repo}", - headers=headers, timeout=15, + headers=headers, + timeout=15, ) default_branch = resp.json().get("default_branch", "main") except Exception: @@ -653,7 +673,8 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str, try: resp = httpx.get( f"https://api.github.com/repos/{fork_repo}/git/refs/heads/{default_branch}", - headers=headers, timeout=15, + headers=headers, + timeout=15, ) base_sha = resp.json()["object"]["sha"] except Exception as e: @@ -664,7 +685,8 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str, try: httpx.post( f"https://api.github.com/repos/{fork_repo}/git/refs", - headers=headers, timeout=15, + headers=headers, + timeout=15, json={"ref": f"refs/heads/{branch_name}", "sha": base_sha}, ) except Exception as e: @@ -678,10 +700,12 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str, upload_path = f"skills/{skill_name}/{rel}" try: import base64 + content_b64 = base64.b64encode(f.read_bytes()).decode() httpx.put( f"https://api.github.com/repos/{fork_repo}/contents/{upload_path}", - headers=headers, timeout=15, + headers=headers, + timeout=15, json={ "message": f"Add {skill_name} skill: {rel}", "content": content_b64, @@ -695,11 +719,12 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str, try: resp = httpx.post( f"https://api.github.com/repos/{target_repo}/pulls", - headers=headers, timeout=15, + headers=headers, + timeout=15, json={ "title": f"Add skill: {skill_name}", "body": f"Submitting the `{skill_name}` skill via Hermes Skills Hub.\n\n" - f"This skill was scanned by the Hermes Skills Guard before submission.", + f"This skill was scanned by the Hermes Skills Guard before submission.", "head": f"{fork_repo.split('/')[0]}:{branch_name}", "base": default_branch, }, @@ -713,7 +738,7 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str, return False, f"Network error creating PR: {e}" -def do_snapshot_export(output_path: str, console: Optional[Console] = None) -> None: +def do_snapshot_export(output_path: str, console: Console | None = None) -> None: """Export current hub skill configuration to a portable JSON file.""" from tools.skills_hub import HubLockFile, TapsManager @@ -726,16 +751,15 @@ def do_snapshot_export(output_path: str, console: Optional[Console] = None) -> N snapshot = { "hermes_version": "0.1.0", - "exported_at": __import__("datetime").datetime.now( - __import__("datetime").timezone.utc - ).isoformat(), + "exported_at": __import__("datetime").datetime.now(__import__("datetime").timezone.utc).isoformat(), "skills": [ { "name": entry["name"], "source": entry.get("source", ""), "identifier": entry.get("identifier", ""), "category": str(Path(entry.get("install_path", "")).parent) - if "/" in entry.get("install_path", "") else "", + if "/" in entry.get("install_path", "") + else "", } for entry in installed ], @@ -748,8 +772,7 @@ def do_snapshot_export(output_path: str, console: Optional[Console] = None) -> N c.print(f"[dim]{len(installed)} skill(s), {len(tap_list)} tap(s)[/]\n") -def do_snapshot_import(input_path: str, force: bool = False, - console: Optional[Console] = None) -> None: +def do_snapshot_import(input_path: str, force: bool = False, console: Console | None = None) -> None: """Re-install skills from a snapshot file.""" from tools.skills_hub import TapsManager @@ -799,6 +822,7 @@ def do_snapshot_import(input_path: str, force: bool = False, # CLI argparse entry point # --------------------------------------------------------------------------- + def skills_command(args) -> None: """Router for `hermes skills ` — called from hermes_cli/main.py.""" action = getattr(args, "skills_action", None) @@ -839,7 +863,9 @@ def skills_command(args) -> None: return do_tap(tap_action, repo=repo) else: - _console.print("Usage: hermes skills [browse|search|install|inspect|list|audit|uninstall|publish|snapshot|tap]\n") + _console.print( + "Usage: hermes skills [browse|search|install|inspect|list|audit|uninstall|publish|snapshot|tap]\n" + ) _console.print("Run 'hermes skills --help' for details.\n") @@ -847,7 +873,8 @@ def skills_command(args) -> None: # Slash command entry point (/skills in chat) # --------------------------------------------------------------------------- -def handle_skills_slash(cmd: str, console: Optional[Console] = None) -> None: + +def handle_skills_slash(cmd: str, console: Console | None = None) -> None: """ Parse and dispatch `/skills [args]` from the chat interface. @@ -1008,17 +1035,19 @@ def handle_skills_slash(cmd: str, console: Optional[Console] = None) -> None: def _print_skills_help(console: Console) -> None: """Print help for the /skills slash command.""" - console.print(Panel( - "[bold]Skills Hub Commands:[/]\n\n" - " [cyan]browse[/] [--source official] Browse all available skills (paginated)\n" - " [cyan]search[/] Search registries for skills\n" - " [cyan]install[/] Install a skill (with security scan)\n" - " [cyan]inspect[/] Preview a skill without installing\n" - " [cyan]list[/] [--source hub|builtin] List installed skills\n" - " [cyan]audit[/] [name] Re-scan hub skills for security\n" - " [cyan]uninstall[/] Remove a hub-installed skill\n" - " [cyan]publish[/] --repo Publish a skill to GitHub via PR\n" - " [cyan]snapshot[/] export|import Export/import skill configurations\n" - " [cyan]tap[/] list|add|remove Manage skill sources\n", - title="/skills", - )) + console.print( + Panel( + "[bold]Skills Hub Commands:[/]\n\n" + " [cyan]browse[/] [--source official] Browse all available skills (paginated)\n" + " [cyan]search[/] Search registries for skills\n" + " [cyan]install[/] Install a skill (with security scan)\n" + " [cyan]inspect[/] Preview a skill without installing\n" + " [cyan]list[/] [--source hub|builtin] List installed skills\n" + " [cyan]audit[/] [name] Re-scan hub skills for security\n" + " [cyan]uninstall[/] Remove a hub-installed skill\n" + " [cyan]publish[/] --repo Publish a skill to GitHub via PR\n" + " [cyan]snapshot[/] export|import Export/import skill configurations\n" + " [cyan]tap[/] list|add|remove Manage skill sources\n", + title="/skills", + ) + ) diff --git a/hermes_cli/status.py b/hermes_cli/status.py index 12b064fea6..a47481809a 100644 --- a/hermes_cli/status.py +++ b/hermes_cli/status.py @@ -5,21 +5,25 @@ Shows the status of all Hermes Agent components. """ import os -import sys import subprocess +import sys from pathlib import Path PROJECT_ROOT = Path(__file__).parent.parent.resolve() +from datetime import UTC + from hermes_cli.colors import Colors, color from hermes_cli.config import get_env_path, get_env_value from hermes_constants import OPENROUTER_MODELS_URL + def check_mark(ok: bool) -> str: if ok: return color("✓", Colors.GREEN) return color("✗", Colors.RED) + def redact_key(key: str) -> str: """Redact an API key for display.""" if not key: @@ -33,7 +37,8 @@ def _format_iso_timestamp(value) -> str: """Format ISO timestamps for status output, converting to local timezone.""" if not value or not isinstance(value, str): return "(unknown)" - from datetime import datetime, timezone + from datetime import datetime + text = value.strip() if not text: return "(unknown)" @@ -42,7 +47,7 @@ def _format_iso_timestamp(value) -> str: try: parsed = datetime.fromisoformat(text) if parsed.tzinfo is None: - parsed = parsed.replace(tzinfo=timezone.utc) + parsed = parsed.replace(tzinfo=UTC) except Exception: return value return parsed.astimezone().strftime("%Y-%m-%d %H:%M:%S %Z") @@ -50,14 +55,14 @@ def _format_iso_timestamp(value) -> str: def show_status(args): """Show status of all Hermes Agent components.""" - show_all = getattr(args, 'all', False) - deep = getattr(args, 'deep', False) - + show_all = getattr(args, "all", False) + deep = getattr(args, "deep", False) + print() print(color("┌─────────────────────────────────────────────────────────┐", Colors.CYAN)) print(color("│ ⚕ Hermes Agent Status │", Colors.CYAN)) print(color("└─────────────────────────────────────────────────────────┘", Colors.CYAN)) - + # ========================================================================= # Environment # ========================================================================= @@ -65,19 +70,19 @@ def show_status(args): print(color("◆ Environment", Colors.CYAN, Colors.BOLD)) print(f" Project: {PROJECT_ROOT}") print(f" Python: {sys.version.split()[0]}") - + env_path = get_env_path() print(f" .env file: {check_mark(env_path.exists())} {'exists' if env_path.exists() else 'not found'}") - + # ========================================================================= # API Keys # ========================================================================= print() print(color("◆ API Keys", Colors.CYAN, Colors.BOLD)) - + keys = { "OpenRouter": "OPENROUTER_API_KEY", - "Anthropic": "ANTHROPIC_API_KEY", + "Anthropic": "ANTHROPIC_API_KEY", "OpenAI": "OPENAI_API_KEY", "Z.AI/GLM": "GLM_API_KEY", "Kimi": "KIMI_API_KEY", @@ -91,7 +96,7 @@ def show_status(args): "ElevenLabs": "ELEVENLABS_API_KEY", "GitHub": "GITHUB_TOKEN", } - + for name, env_var in keys.items(): value = get_env_value(env_var) or "" has_key = bool(value) @@ -105,7 +110,8 @@ def show_status(args): print(color("◆ Auth Providers", Colors.CYAN, Colors.BOLD)) try: - from hermes_cli.auth import get_nous_auth_status, get_codex_auth_status + from hermes_cli.auth import get_codex_auth_status, get_nous_auth_status + nous_status = get_nous_auth_status() codex_status = get_codex_auth_status() except Exception: @@ -148,10 +154,10 @@ def show_status(args): print(color("◆ API-Key Providers", Colors.CYAN, Colors.BOLD)) apikey_providers = { - "Z.AI / GLM": ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"), - "Kimi / Moonshot": ("KIMI_API_KEY",), - "MiniMax": ("MINIMAX_API_KEY",), - "MiniMax (China)": ("MINIMAX_CN_API_KEY",), + "Z.AI / GLM": ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"), + "Kimi / Moonshot": ("KIMI_API_KEY",), + "MiniMax": ("MINIMAX_API_KEY",), + "MiniMax (China)": ("MINIMAX_CN_API_KEY",), } for pname, env_vars in apikey_providers.items(): key_val = "" @@ -168,19 +174,20 @@ def show_status(args): # ========================================================================= print() print(color("◆ Terminal Backend", Colors.CYAN, Colors.BOLD)) - + terminal_env = os.getenv("TERMINAL_ENV", "") if not terminal_env: # Fall back to config file value when env var isn't set # (hermes status doesn't go through cli.py's config loading) try: from hermes_cli.config import load_config + _cfg = load_config() terminal_env = _cfg.get("terminal", {}).get("backend", "local") except Exception: terminal_env = "local" print(f" Backend: {terminal_env}") - + if terminal_env == "ssh": ssh_host = os.getenv("TERMINAL_SSH_HOST", "") ssh_user = os.getenv("TERMINAL_SSH_USER", "") @@ -192,16 +199,16 @@ def show_status(args): elif terminal_env == "daytona": daytona_image = os.getenv("TERMINAL_DAYTONA_IMAGE", "nikolaik/python-nodejs:python3.11-nodejs20") print(f" Daytona Image: {daytona_image}") - + sudo_password = os.getenv("SUDO_PASSWORD", "") print(f" Sudo: {check_mark(bool(sudo_password))} {'enabled' if sudo_password else 'disabled'}") - + # ========================================================================= # Messaging Platforms # ========================================================================= print() print(color("◆ Messaging Platforms", Colors.CYAN, Colors.BOLD)) - + platforms = { "Telegram": ("TELEGRAM_BOT_TOKEN", "TELEGRAM_HOME_CHANNEL"), "Discord": ("DISCORD_BOT_TOKEN", "DISCORD_HOME_CHANNEL"), @@ -209,59 +216,52 @@ def show_status(args): "Signal": ("SIGNAL_HTTP_URL", "SIGNAL_HOME_CHANNEL"), "Slack": ("SLACK_BOT_TOKEN", None), } - + for name, (token_var, home_var) in platforms.items(): token = os.getenv(token_var, "") has_token = bool(token) - + home_channel = "" if home_var: home_channel = os.getenv(home_var, "") - + status = "configured" if has_token else "not configured" if home_channel: status += f" (home: {home_channel})" - + print(f" {name:<12} {check_mark(has_token)} {status}") - + # ========================================================================= # Gateway Status # ========================================================================= print() print(color("◆ Gateway Service", Colors.CYAN, Colors.BOLD)) - - if sys.platform.startswith('linux'): - result = subprocess.run( - ["systemctl", "--user", "is-active", "hermes-gateway"], - capture_output=True, - text=True - ) + + if sys.platform.startswith("linux"): + result = subprocess.run(["systemctl", "--user", "is-active", "hermes-gateway"], capture_output=True, text=True) is_active = result.stdout.strip() == "active" print(f" Status: {check_mark(is_active)} {'running' if is_active else 'stopped'}") - print(f" Manager: systemd (user)") - - elif sys.platform == 'darwin': - result = subprocess.run( - ["launchctl", "list", "ai.hermes.gateway"], - capture_output=True, - text=True - ) + print(" Manager: systemd (user)") + + elif sys.platform == "darwin": + result = subprocess.run(["launchctl", "list", "ai.hermes.gateway"], capture_output=True, text=True) is_loaded = result.returncode == 0 print(f" Status: {check_mark(is_loaded)} {'loaded' if is_loaded else 'not loaded'}") - print(f" Manager: launchd") + print(" Manager: launchd") else: print(f" Status: {color('N/A', Colors.DIM)}") - print(f" Manager: (not supported on this platform)") - + print(" Manager: (not supported on this platform)") + # ========================================================================= # Cron Jobs # ========================================================================= print() print(color("◆ Scheduled Jobs", Colors.CYAN, Colors.BOLD)) - + jobs_file = Path.home() / ".hermes" / "cron" / "jobs.json" if jobs_file.exists(): import json + try: with open(jobs_file) as f: data = json.load(f) @@ -269,56 +269,57 @@ def show_status(args): enabled_jobs = [j for j in jobs if j.get("enabled", True)] print(f" Jobs: {len(enabled_jobs)} active, {len(jobs)} total") except Exception: - print(f" Jobs: (error reading jobs file)") + print(" Jobs: (error reading jobs file)") else: - print(f" Jobs: 0") - + print(" Jobs: 0") + # ========================================================================= # Sessions # ========================================================================= print() print(color("◆ Sessions", Colors.CYAN, Colors.BOLD)) - + sessions_file = Path.home() / ".hermes" / "sessions" / "sessions.json" if sessions_file.exists(): import json + try: with open(sessions_file) as f: data = json.load(f) print(f" Active: {len(data)} session(s)") except Exception: - print(f" Active: (error reading sessions file)") + print(" Active: (error reading sessions file)") else: - print(f" Active: 0") - + print(" Active: 0") + # ========================================================================= # Deep checks # ========================================================================= if deep: print() print(color("◆ Deep Checks", Colors.CYAN, Colors.BOLD)) - + # Check OpenRouter connectivity openrouter_key = os.getenv("OPENROUTER_API_KEY", "") if openrouter_key: try: import httpx + response = httpx.get( - OPENROUTER_MODELS_URL, - headers={"Authorization": f"Bearer {openrouter_key}"}, - timeout=10 + OPENROUTER_MODELS_URL, headers={"Authorization": f"Bearer {openrouter_key}"}, timeout=10 ) ok = response.status_code == 200 print(f" OpenRouter: {check_mark(ok)} {'reachable' if ok else f'error ({response.status_code})'}") except Exception as e: print(f" OpenRouter: {check_mark(False)} error: {e}") - + # Check gateway port try: import socket + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(1) - result = sock.connect_ex(('127.0.0.1', 18789)) + result = sock.connect_ex(("127.0.0.1", 18789)) sock.close() # Port in use = gateway likely running port_in_use = result == 0 @@ -326,7 +327,7 @@ def show_status(args): print(f" Port 18789: {'in use' if port_in_use else 'available'}") except OSError: pass - + print() print(color("─" * 60, Colors.DIM)) print(color(" Run 'hermes doctor' for detailed diagnostics", Colors.DIM)) diff --git a/hermes_cli/tools_config.py b/hermes_cli/tools_config.py index 19288bf59f..56bc6ccbb8 100644 --- a/hermes_cli/tools_config.py +++ b/hermes_cli/tools_config.py @@ -11,33 +11,37 @@ the `platform_toolsets` key. import sys from pathlib import Path -from typing import Dict, List, Set -import os - -from hermes_cli.config import ( - load_config, save_config, get_env_value, save_env_value, - get_hermes_home, -) from hermes_cli.colors import Colors, color +from hermes_cli.config import ( + get_env_value, + load_config, + save_config, + save_env_value, +) PROJECT_ROOT = Path(__file__).parent.parent.resolve() # ─── UI Helpers (shared with setup.py) ──────────────────────────────────────── + def _print_info(text: str): print(color(f" {text}", Colors.DIM)) + def _print_success(text: str): print(color(f"✓ {text}", Colors.GREEN)) + def _print_warning(text: str): print(color(f"⚠ {text}", Colors.YELLOW)) + def _print_error(text: str): print(color(f"✗ {text}", Colors.RED)) + def _prompt(question: str, default: str = None, password: bool = False) -> str: if default: display = f"{question} [{default}]: " @@ -46,6 +50,7 @@ def _prompt(question: str, default: str = None, password: bool = False) -> str: try: if password: import getpass + value = getpass.getpass(color(display, Colors.YELLOW)) else: value = input(color(display, Colors.YELLOW)) @@ -54,6 +59,7 @@ def _prompt(question: str, default: str = None, password: bool = False) -> str: print() return default or "" + def _prompt_yes_no(question: str, default: bool = True) -> bool: default_str = "Y/n" if default else "y/N" while True: @@ -64,9 +70,9 @@ def _prompt_yes_no(question: str, default: bool = True) -> bool: return default if not value: return default - if value in ('y', 'yes'): + if value in ("y", "yes"): return True - if value in ('n', 'no'): + if value in ("n", "no"): return False @@ -76,24 +82,24 @@ def _prompt_yes_no(question: str, default: bool = True) -> bool: # Each entry: (toolset_name, label, description) # These map to keys in toolsets.py TOOLSETS dict. CONFIGURABLE_TOOLSETS = [ - ("web", "🔍 Web Search & Scraping", "web_search, web_extract"), - ("browser", "🌐 Browser Automation", "navigate, click, type, scroll"), - ("terminal", "💻 Terminal & Processes", "terminal, process"), - ("file", "📁 File Operations", "read, write, patch, search"), - ("code_execution", "⚡ Code Execution", "execute_code"), - ("vision", "👁️ Vision / Image Analysis", "vision_analyze"), - ("image_gen", "🎨 Image Generation", "image_generate"), - ("moa", "🧠 Mixture of Agents", "mixture_of_agents"), - ("tts", "🔊 Text-to-Speech", "text_to_speech"), - ("skills", "📚 Skills", "list, view, manage"), - ("todo", "📋 Task Planning", "todo"), - ("memory", "💾 Memory", "persistent memory across sessions"), - ("session_search", "🔎 Session Search", "search past conversations"), - ("clarify", "❓ Clarifying Questions", "clarify"), - ("delegation", "👥 Task Delegation", "delegate_task"), - ("cronjob", "⏰ Cron Jobs", "schedule, list, remove"), - ("rl", "🧪 RL Training", "Tinker-Atropos training tools"), - ("homeassistant", "🏠 Home Assistant", "smart home device control"), + ("web", "🔍 Web Search & Scraping", "web_search, web_extract"), + ("browser", "🌐 Browser Automation", "navigate, click, type, scroll"), + ("terminal", "💻 Terminal & Processes", "terminal, process"), + ("file", "📁 File Operations", "read, write, patch, search"), + ("code_execution", "⚡ Code Execution", "execute_code"), + ("vision", "👁️ Vision / Image Analysis", "vision_analyze"), + ("image_gen", "🎨 Image Generation", "image_generate"), + ("moa", "🧠 Mixture of Agents", "mixture_of_agents"), + ("tts", "🔊 Text-to-Speech", "text_to_speech"), + ("skills", "📚 Skills", "list, view, manage"), + ("todo", "📋 Task Planning", "todo"), + ("memory", "💾 Memory", "persistent memory across sessions"), + ("session_search", "🔎 Session Search", "search past conversations"), + ("clarify", "❓ Clarifying Questions", "clarify"), + ("delegation", "👥 Task Delegation", "delegate_task"), + ("cronjob", "⏰ Cron Jobs", "schedule, list, remove"), + ("rl", "🧪 RL Training", "Tinker-Atropos training tools"), + ("homeassistant", "🏠 Home Assistant", "smart home device control"), ] # Toolsets that are OFF by default for new installs. @@ -103,11 +109,11 @@ _DEFAULT_OFF_TOOLSETS = {"moa", "homeassistant", "rl"} # Platform display config PLATFORMS = { - "cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"}, - "telegram": {"label": "📱 Telegram", "default_toolset": "hermes-telegram"}, - "discord": {"label": "💬 Discord", "default_toolset": "hermes-discord"}, - "slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"}, - "whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"}, + "cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"}, + "telegram": {"label": "📱 Telegram", "default_toolset": "hermes-telegram"}, + "discord": {"label": "💬 Discord", "default_toolset": "hermes-discord"}, + "slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"}, + "whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"}, } @@ -131,7 +137,11 @@ TOOL_CATEGORIES = { "name": "OpenAI TTS", "tag": "Premium - high quality voices", "env_vars": [ - {"key": "VOICE_TOOLS_OPENAI_KEY", "prompt": "OpenAI API key", "url": "https://platform.openai.com/api-keys"}, + { + "key": "VOICE_TOOLS_OPENAI_KEY", + "prompt": "OpenAI API key", + "url": "https://platform.openai.com/api-keys", + }, ], "tts_provider": "openai", }, @@ -139,7 +149,11 @@ TOOL_CATEGORIES = { "name": "ElevenLabs", "tag": "Premium - most natural voices", "env_vars": [ - {"key": "ELEVENLABS_API_KEY", "prompt": "ElevenLabs API key", "url": "https://elevenlabs.io/app/settings/api-keys"}, + { + "key": "ELEVENLABS_API_KEY", + "prompt": "ElevenLabs API key", + "url": "https://elevenlabs.io/app/settings/api-keys", + }, ], "tts_provider": "elevenlabs", }, @@ -224,7 +238,11 @@ TOOL_CATEGORIES = { "name": "Tinker / Atropos", "tag": "RL training platform", "env_vars": [ - {"key": "TINKER_API_KEY", "prompt": "Tinker API key", "url": "https://tinker-console.thinkingmachines.ai/keys"}, + { + "key": "TINKER_API_KEY", + "prompt": "Tinker API key", + "url": "https://tinker-console.thinkingmachines.ai/keys", + }, {"key": "WANDB_API_KEY", "prompt": "WandB API key", "url": "https://wandb.ai/authorize"}, ], "post_setup": "rl_training", @@ -236,24 +254,26 @@ TOOL_CATEGORIES = { # Simple env-var requirements for toolsets NOT in TOOL_CATEGORIES. # Used as a fallback for tools like vision/moa that just need an API key. TOOLSET_ENV_REQUIREMENTS = { - "vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")], - "moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")], + "vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")], + "moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")], } # ─── Post-Setup Hooks ───────────────────────────────────────────────────────── + def _run_post_setup(post_setup_key: str): """Run post-setup hooks for tools that need extra installation steps.""" import shutil + if post_setup_key == "browserbase": node_modules = PROJECT_ROOT / "node_modules" / "agent-browser" if not node_modules.exists() and shutil.which("npm"): _print_info(" Installing Node.js dependencies for browser tools...") import subprocess + result = subprocess.run( - ["npm", "install", "--silent"], - capture_output=True, text=True, cwd=str(PROJECT_ROOT) + ["npm", "install", "--silent"], capture_output=True, text=True, cwd=str(PROJECT_ROOT) ) if result.returncode == 0: _print_success(" Node.js dependencies installed") @@ -270,16 +290,17 @@ def _run_post_setup(post_setup_key: str): if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists(): _print_info(" Installing tinker-atropos submodule...") import subprocess + uv_bin = shutil.which("uv") if uv_bin: result = subprocess.run( [uv_bin, "pip", "install", "--python", sys.executable, "-e", str(tinker_dir)], - capture_output=True, text=True + capture_output=True, + text=True, ) else: result = subprocess.run( - [sys.executable, "-m", "pip", "install", "-e", str(tinker_dir)], - capture_output=True, text=True + [sys.executable, "-m", "pip", "install", "-e", str(tinker_dir)], capture_output=True, text=True ) if result.returncode == 0: _print_success(" tinker-atropos installed") @@ -294,7 +315,8 @@ def _run_post_setup(post_setup_key: str): # ─── Platform / Toolset Helpers ─────────────────────────────────────────────── -def _get_enabled_platforms() -> List[str]: + +def _get_enabled_platforms() -> list[str]: """Return platform keys that are configured (have tokens or are CLI).""" enabled = ["cli"] if get_env_value("TELEGRAM_BOT_TOKEN"): @@ -308,9 +330,9 @@ def _get_enabled_platforms() -> List[str]: return enabled -def _get_platform_tools(config: dict, platform: str) -> Set[str]: +def _get_platform_tools(config: dict, platform: str) -> set[str]: """Resolve which individual toolset names are enabled for a platform.""" - from toolsets import resolve_toolset, TOOLSETS + from toolsets import resolve_toolset platform_toolsets = config.get("platform_toolsets", {}) toolset_names = platform_toolsets.get(platform) @@ -335,7 +357,7 @@ def _get_platform_tools(config: dict, platform: str) -> Set[str]: return enabled_toolsets -def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: Set[str]): +def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: set[str]): """Save the selected toolset keys for a platform to config.""" config.setdefault("platform_toolsets", {}) config["platform_toolsets"][platform] = sorted(enabled_toolset_keys) @@ -364,6 +386,7 @@ def _toolset_has_keys(ts_key: str) -> bool: # ─── Menu Helpers ───────────────────────────────────────────────────────────── + def _prompt_choice(question: str, choices: list, default: int = 0) -> int: """Single-select menu (arrow keys). Uses curses to avoid simple_term_menu rendering bugs in tmux, iTerm, and other non-standard terminals.""" @@ -371,6 +394,7 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int: # Curses-based single-select — works in tmux, iTerm, and standard terminals try: import curses + result_holder = [default] def _curses_menu(stdscr): @@ -386,8 +410,9 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int: stdscr.clear() max_y, max_x = stdscr.getmaxyx() try: - stdscr.addnstr(0, 0, question, max_x - 1, - curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0)) + stdscr.addnstr( + 0, 0, question, max_x - 1, curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0) + ) except curses.error: pass @@ -410,14 +435,14 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int: stdscr.refresh() key = stdscr.getch() - if key in (curses.KEY_UP, ord('k')): + if key in (curses.KEY_UP, ord("k")): cursor = (cursor - 1) % len(choices) - elif key in (curses.KEY_DOWN, ord('j')): + elif key in (curses.KEY_DOWN, ord("j")): cursor = (cursor + 1) % len(choices) elif key in (curses.KEY_ENTER, 10, 13): result_holder[0] = cursor return - elif key in (27, ord('q')): + elif key in (27, ord("q")): return curses.wrapper(_curses_menu) @@ -431,7 +456,7 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int: for i, c in enumerate(choices): marker = "●" if i == default else "○" style = Colors.GREEN if i == default else "" - print(color(f" {marker} {i+1}. {c}", style) if style else f" {marker} {i+1}. {c}") + print(color(f" {marker} {i + 1}. {c}", style) if style else f" {marker} {i + 1}. {c}") while True: try: val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM)) @@ -445,7 +470,7 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int: return default -def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str]: +def _prompt_toolset_checklist(platform_label: str, enabled: set[str]) -> set[str]: """Multi-select checklist of toolsets. Returns set of selected toolset keys.""" labels = [] @@ -455,15 +480,13 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str suffix = " [no API key]" labels.append(f"{ts_label} ({ts_desc}){suffix}") - pre_selected_indices = [ - i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS) - if ts_key in enabled - ] + pre_selected_indices = [i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS) if ts_key in enabled] # Curses-based multi-select — arrow keys + space to toggle + enter to confirm. # simple_term_menu has rendering bugs in tmux, iTerm, and other terminals. try: import curses + selected = set(pre_selected_indices) result_holder = [None] @@ -483,7 +506,13 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str max_y, max_x = stdscr.getmaxyx() header = f"Tools for {platform_label} — ↑↓ navigate, SPACE toggle, ENTER confirm" try: - stdscr.addnstr(0, 0, header, max_x - 1, curses.A_BOLD | curses.color_pair(2) if curses.has_colors() else curses.A_BOLD) + stdscr.addnstr( + 0, + 0, + header, + max_x - 1, + curses.A_BOLD | curses.color_pair(2) if curses.has_colors() else curses.A_BOLD, + ) except curses.error: pass @@ -514,11 +543,11 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str stdscr.refresh() key = stdscr.getch() - if key in (curses.KEY_UP, ord('k')): + if key in (curses.KEY_UP, ord("k")): cursor = (cursor - 1) % len(labels) - elif key in (curses.KEY_DOWN, ord('j')): + elif key in (curses.KEY_DOWN, ord("j")): cursor = (cursor + 1) % len(labels) - elif key == ord(' '): + elif key == ord(" "): if cursor in selected: selected.discard(cursor) else: @@ -526,7 +555,7 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str elif key in (curses.KEY_ENTER, 10, 13): result_holder[0] = {CONFIGURABLE_TOOLSETS[i][0] for i in selected} return - elif key in (27, ord('q')): # ESC or q + elif key in (27, ord("q")): # ESC or q result_holder[0] = enabled return @@ -565,9 +594,10 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str # ─── Provider-Aware Configuration ──────────────────────────────────────────── + def _configure_toolset(ts_key: str, config: dict): """Configure a toolset - provider selection + API keys. - + Uses TOOL_CATEGORIES for provider-aware config, falls back to simple env var prompts for toolsets not in TOOL_CATEGORIES. """ @@ -591,7 +621,9 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict): req = cat["requires_python"] if sys.version_info < req: print() - _print_error(f" {name} requires Python {req[0]}.{req[1]}+ (current: {sys.version_info.major}.{sys.version_info.minor})") + _print_error( + f" {name} requires Python {req[0]}.{req[1]}+ (current: {sys.version_info.major}.{sys.version_info.minor})" + ) _print_info(" Upgrade Python and reinstall to enable this tool.") return @@ -610,7 +642,7 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict): # Multiple providers - let user choose print() # Use custom title if provided (e.g. "Select Search Provider") - title = cat.get("setup_title", f"Choose a provider") + title = cat.get("setup_title", "Choose a provider") print(color(f" --- {icon} {name} - {title} ---", Colors.CYAN)) if cat.get("setup_note"): _print_info(f" {cat['setup_note']}") @@ -626,7 +658,11 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict): if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]: configured = " [active]" elif not env_vars: - configured = " [active]" if config.get("tts", {}).get("provider", "edge") == p.get("tts_provider", "") else "" + configured = ( + " [active]" + if config.get("tts", {}).get("provider", "edge") == p.get("tts_provider", "") + else "" + ) else: configured = " [configured]" provider_choices.append(f"{p['name']}{tag}{configured}") @@ -688,9 +724,9 @@ def _configure_provider(provider: dict, config: dict): if value: save_env_value(var["key"], value) - _print_success(f" Saved") + _print_success(" Saved") else: - _print_warning(f" Skipped") + _print_warning(" Skipped") all_configured = False # Run post-setup hooks if needed @@ -721,9 +757,9 @@ def _configure_simple_requirements(ts_key: str): value = _prompt(f" {var}", password=True) if value and value.strip(): save_env_value(var, value.strip()) - _print_success(f" Saved") + _print_success(" Saved") else: - _print_warning(f" Skipped") + _print_warning(" Skipped") def _reconfigure_tool(config: dict): @@ -827,9 +863,9 @@ def _reconfigure_provider(provider: dict, config: dict): value = _prompt(f" {var.get('prompt', var['key'])} (Enter to keep current)", password=not default_val) if value and value.strip(): save_env_value(var["key"], value.strip()) - _print_success(f" Updated") + _print_success(" Updated") else: - _print_info(f" Kept current") + _print_info(" Kept current") def _reconfigure_simple_requirements(ts_key: str): @@ -851,13 +887,14 @@ def _reconfigure_simple_requirements(ts_key: str): value = _prompt(f" {var} (Enter to keep current)", password=True) if value and value.strip(): save_env_value(var, value.strip()) - _print_success(f" Updated") + _print_success(" Updated") else: - _print_info(f" Kept current") + _print_info(" Kept current") # ─── Main Entry Point ───────────────────────────────────────────────────────── + def tools_command(args=None, first_install: bool = False, config: dict = None): """Entry point for `hermes tools` and `hermes setup tools`. @@ -907,7 +944,8 @@ def tools_command(args=None, first_install: bool = False, config: dict = None): # TTS (Edge vs OpenAI vs ElevenLabs), etc. are shown even when # a free provider exists. to_configure = [ - ts_key for ts_key in sorted(new_enabled) + ts_key + for ts_key in sorted(new_enabled) if TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key) ] @@ -981,7 +1019,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None): # Configure newly enabled toolsets that need API keys for ts_key in sorted(added): - if (TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key)): + if TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key): if not _toolset_has_keys(ts_key): _configure_toolset(ts_key, config) diff --git a/hermes_cli/uninstall.py b/hermes_cli/uninstall.py index d70405ce31..723c788d1f 100644 --- a/hermes_cli/uninstall.py +++ b/hermes_cli/uninstall.py @@ -7,23 +7,25 @@ Provides options for: """ import os -import sys import shutil import subprocess from pathlib import Path -from typing import Optional from hermes_cli.colors import Colors, color + def log_info(msg: str): print(f"{color('→', Colors.CYAN)} {msg}") + def log_success(msg: str): print(f"{color('✓', Colors.GREEN)} {msg}") + def log_warn(msg: str): print(f"{color('⚠', Colors.YELLOW)} {msg}") + def log_error(msg: str): print(f"{color('✗', Colors.RED)} {msg}") @@ -42,7 +44,7 @@ def find_shell_configs() -> list: """Find shell configuration files that might have PATH entries.""" home = Path.home() configs = [] - + candidates = [ home / ".bashrc", home / ".bash_profile", @@ -50,11 +52,11 @@ def find_shell_configs() -> list: home / ".zshrc", home / ".zprofile", ] - + for config in candidates: if config.exists(): configs.append(config) - + return configs @@ -62,45 +64,45 @@ def remove_path_from_shell_configs(): """Remove Hermes PATH entries from shell configuration files.""" configs = find_shell_configs() removed_from = [] - + for config_path in configs: try: content = config_path.read_text() original_content = content - + # Remove lines containing hermes-agent or hermes PATH entries new_lines = [] skip_next = False - - for line in content.split('\n'): + + for line in content.split("\n"): # Skip the "# Hermes Agent" comment and following line - if '# Hermes Agent' in line or '# hermes-agent' in line: + if "# Hermes Agent" in line or "# hermes-agent" in line: skip_next = True continue - if skip_next and ('hermes' in line.lower() and 'PATH' in line): + if skip_next and ("hermes" in line.lower() and "PATH" in line): skip_next = False continue skip_next = False - + # Remove any PATH line containing hermes - if 'hermes' in line.lower() and ('PATH=' in line or 'path=' in line.lower()): + if "hermes" in line.lower() and ("PATH=" in line or "path=" in line.lower()): continue - + new_lines.append(line) - - new_content = '\n'.join(new_lines) - + + new_content = "\n".join(new_lines) + # Clean up multiple blank lines - while '\n\n\n' in new_content: - new_content = new_content.replace('\n\n\n', '\n\n') - + while "\n\n\n" in new_content: + new_content = new_content.replace("\n\n\n", "\n\n") + if new_content != original_content: config_path.write_text(new_content) removed_from.append(config_path) - + except Exception as e: log_warn(f"Could not update {config_path}: {e}") - + return removed_from @@ -110,61 +112,49 @@ def remove_wrapper_script(): Path.home() / ".local" / "bin" / "hermes", Path("/usr/local/bin/hermes"), ] - + removed = [] for wrapper in wrapper_paths: if wrapper.exists(): try: # Check if it's our wrapper (contains hermes_cli reference) content = wrapper.read_text() - if 'hermes_cli' in content or 'hermes-agent' in content: + if "hermes_cli" in content or "hermes-agent" in content: wrapper.unlink() removed.append(wrapper) except Exception as e: log_warn(f"Could not remove {wrapper}: {e}") - + return removed def uninstall_gateway_service(): """Stop and uninstall the gateway service if running.""" import platform - + if platform.system() != "Linux": return False - + service_file = Path.home() / ".config" / "systemd" / "user" / "hermes-gateway.service" - + if not service_file.exists(): return False - + try: # Stop the service - subprocess.run( - ["systemctl", "--user", "stop", "hermes-gateway"], - capture_output=True, - check=False - ) - + subprocess.run(["systemctl", "--user", "stop", "hermes-gateway"], capture_output=True, check=False) + # Disable the service - subprocess.run( - ["systemctl", "--user", "disable", "hermes-gateway"], - capture_output=True, - check=False - ) - + subprocess.run(["systemctl", "--user", "disable", "hermes-gateway"], capture_output=True, check=False) + # Remove service file service_file.unlink() - + # Reload systemd - subprocess.run( - ["systemctl", "--user", "daemon-reload"], - capture_output=True, - check=False - ) - + subprocess.run(["systemctl", "--user", "daemon-reload"], capture_output=True, check=False) + return True - + except Exception as e: log_warn(f"Could not fully remove gateway service: {e}") return False @@ -173,20 +163,20 @@ def uninstall_gateway_service(): def run_uninstall(args): """ Run the uninstall process. - + Options: - Full uninstall: removes code + ~/.hermes/ (configs, data, logs) - Keep data: removes code but keeps ~/.hermes/ for future reinstall """ project_root = get_project_root() hermes_home = get_hermes_home() - + print() print(color("┌─────────────────────────────────────────────────────────┐", Colors.MAGENTA, Colors.BOLD)) print(color("│ ⚕ Hermes Agent Uninstaller │", Colors.MAGENTA, Colors.BOLD)) print(color("└─────────────────────────────────────────────────────────┘", Colors.MAGENTA, Colors.BOLD)) print() - + # Show what will be affected print(color("Current Installation:", Colors.CYAN, Colors.BOLD)) print(f" Code: {project_root}") @@ -194,7 +184,7 @@ def run_uninstall(args): print(f" Secrets: {hermes_home / '.env'}") print(f" Data: {hermes_home / 'cron/'}, {hermes_home / 'sessions/'}, {hermes_home / 'logs/'}") print() - + # Ask for confirmation print(color("Uninstall Options:", Colors.YELLOW, Colors.BOLD)) print() @@ -206,21 +196,21 @@ def run_uninstall(args): print() print(" 3) " + color("Cancel", Colors.CYAN) + " - Don't uninstall") print() - + try: choice = input(color("Select option [1/2/3]: ", Colors.BOLD)).strip() except (KeyboardInterrupt, EOFError): print() print("Cancelled.") return - + if choice == "3" or choice.lower() in ("c", "cancel", "q", "quit", "n", "no"): print() print("Uninstall cancelled.") return - - full_uninstall = (choice == "2") - + + full_uninstall = choice == "2" + # Final confirmation print() if full_uninstall: @@ -228,7 +218,7 @@ def run_uninstall(args): print(color(" Including: configs, API keys, sessions, scheduled jobs, logs", Colors.RED)) else: print("This will remove the Hermes code but keep your configuration and data.") - + print() try: confirm = input(f"Type '{color('yes', Colors.YELLOW)}' to confirm: ").strip().lower() @@ -236,23 +226,23 @@ def run_uninstall(args): print() print("Cancelled.") return - + if confirm != "yes": print() print("Uninstall cancelled.") return - + print() print(color("Uninstalling...", Colors.CYAN, Colors.BOLD)) print() - + # 1. Stop and uninstall gateway service log_info("Checking for gateway service...") if uninstall_gateway_service(): log_success("Gateway service stopped and removed") else: log_info("No gateway service found") - + # 2. Remove PATH entries from shell configs log_info("Removing PATH entries from shell configs...") removed_configs = remove_path_from_shell_configs() @@ -261,7 +251,7 @@ def run_uninstall(args): log_success(f"Updated {config}") else: log_info("No PATH entries found to remove") - + # 3. Remove wrapper script log_info("Removing hermes command...") removed_wrappers = remove_wrapper_script() @@ -270,10 +260,10 @@ def run_uninstall(args): log_success(f"Removed {wrapper}") else: log_info("No wrapper script found") - + # 4. Remove installation directory (code) - log_info(f"Removing installation directory...") - + log_info("Removing installation directory...") + # Check if we're running from within the install dir # We need to be careful here try: @@ -289,7 +279,7 @@ def run_uninstall(args): except Exception as e: log_warn(f"Could not fully remove {project_root}: {e}") log_info("You may need to manually remove it") - + # 5. Optionally remove ~/.hermes/ data directory if full_uninstall: log_info("Removing configuration and data...") @@ -302,22 +292,27 @@ def run_uninstall(args): log_info("You may need to manually remove it") else: log_info(f"Keeping configuration and data in {hermes_home}") - + # Done print() print(color("┌─────────────────────────────────────────────────────────┐", Colors.GREEN, Colors.BOLD)) print(color("│ ✓ Uninstall Complete! │", Colors.GREEN, Colors.BOLD)) print(color("└─────────────────────────────────────────────────────────┘", Colors.GREEN, Colors.BOLD)) print() - + if not full_uninstall: print(color("Your configuration and data have been preserved:", Colors.CYAN)) print(f" {hermes_home}/") print() print("To reinstall later with your existing settings:") - print(color(" curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash", Colors.DIM)) + print( + color( + " curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash", + Colors.DIM, + ) + ) print() - + print(color("Reload your shell to complete the process:", Colors.YELLOW)) print(" source ~/.bashrc # or ~/.zshrc") print() diff --git a/hermes_state.py b/hermes_state.py index 67b4484e73..6f24336a6e 100644 --- a/hermes_state.py +++ b/hermes_state.py @@ -19,8 +19,7 @@ import os import sqlite3 import time from pathlib import Path -from typing import Dict, Any, List, Optional - +from typing import Any DEFAULT_DB_PATH = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "state.db" @@ -156,8 +155,7 @@ class SessionDB: # since the title column is guaranteed to exist at this point) try: cursor.execute( - "CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique " - "ON sessions(title) WHERE title IS NOT NULL" + "CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique ON sessions(title) WHERE title IS NOT NULL" ) except sqlite3.OperationalError: pass # Index already exists @@ -185,7 +183,7 @@ class SessionDB: session_id: str, source: str, model: str = None, - model_config: Dict[str, Any] = None, + model_config: dict[str, Any] = None, system_prompt: str = None, user_id: str = None, parent_session_id: str = None, @@ -225,9 +223,7 @@ class SessionDB: ) self._conn.commit() - def update_token_counts( - self, session_id: str, input_tokens: int = 0, output_tokens: int = 0 - ) -> None: + def update_token_counts(self, session_id: str, input_tokens: int = 0, output_tokens: int = 0) -> None: """Increment token counters on a session.""" self._conn.execute( """UPDATE sessions SET @@ -238,11 +234,9 @@ class SessionDB: ) self._conn.commit() - def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: + def get_session(self, session_id: str) -> dict[str, Any] | None: """Get a session by ID.""" - cursor = self._conn.execute( - "SELECT * FROM sessions WHERE id = ?", (session_id,) - ) + cursor = self._conn.execute("SELECT * FROM sessions WHERE id = ?", (session_id,)) row = cursor.fetchone() return dict(row) if row else None @@ -250,7 +244,7 @@ class SessionDB: MAX_TITLE_LENGTH = 100 @staticmethod - def sanitize_title(title: Optional[str]) -> Optional[str]: + def sanitize_title(title: str | None) -> str | None: """Validate and sanitize a session title. - Strips leading/trailing whitespace @@ -271,27 +265,26 @@ class SessionDB: # Remove ASCII control characters (0x00-0x1F, 0x7F) but keep # whitespace chars (\t=0x09, \n=0x0A, \r=0x0D) so they can be # normalized to spaces by the whitespace collapsing step below - cleaned = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', title) + cleaned = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", title) # Remove problematic Unicode control characters: # - Zero-width chars (U+200B-U+200F, U+FEFF) # - Directional overrides (U+202A-U+202E, U+2066-U+2069) # - Object replacement (U+FFFC), interlinear annotation (U+FFF9-U+FFFB) cleaned = re.sub( - r'[\u200b-\u200f\u2028-\u202e\u2060-\u2069\ufeff\ufffc\ufff9-\ufffb]', - '', cleaned, + r"[\u200b-\u200f\u2028-\u202e\u2060-\u2069\ufeff\ufffc\ufff9-\ufffb]", + "", + cleaned, ) # Collapse internal whitespace runs and strip - cleaned = re.sub(r'\s+', ' ', cleaned).strip() + cleaned = re.sub(r"\s+", " ", cleaned).strip() if not cleaned: return None if len(cleaned) > SessionDB.MAX_TITLE_LENGTH: - raise ValueError( - f"Title too long ({len(cleaned)} chars, max {SessionDB.MAX_TITLE_LENGTH})" - ) + raise ValueError(f"Title too long ({len(cleaned)} chars, max {SessionDB.MAX_TITLE_LENGTH})") return cleaned @@ -312,9 +305,7 @@ class SessionDB: ) conflict = cursor.fetchone() if conflict: - raise ValueError( - f"Title '{title}' is already in use by session {conflict['id']}" - ) + raise ValueError(f"Title '{title}' is already in use by session {conflict['id']}") cursor = self._conn.execute( "UPDATE sessions SET title = ? WHERE id = ?", (title, session_id), @@ -322,23 +313,19 @@ class SessionDB: self._conn.commit() return cursor.rowcount > 0 - def get_session_title(self, session_id: str) -> Optional[str]: + def get_session_title(self, session_id: str) -> str | None: """Get the title for a session, or None.""" - cursor = self._conn.execute( - "SELECT title FROM sessions WHERE id = ?", (session_id,) - ) + cursor = self._conn.execute("SELECT title FROM sessions WHERE id = ?", (session_id,)) row = cursor.fetchone() return row["title"] if row else None - def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]: + def get_session_by_title(self, title: str) -> dict[str, Any] | None: """Look up a session by exact title. Returns session dict or None.""" - cursor = self._conn.execute( - "SELECT * FROM sessions WHERE title = ?", (title,) - ) + cursor = self._conn.execute("SELECT * FROM sessions WHERE title = ?", (title,)) row = cursor.fetchone() return dict(row) if row else None - def resolve_session_by_title(self, title: str) -> Optional[str]: + def resolve_session_by_title(self, title: str) -> str | None: """Resolve a title to a session ID, preferring the latest in a lineage. If the exact title exists, returns that session's ID. @@ -353,8 +340,7 @@ class SessionDB: # Escape SQL LIKE wildcards (%, _) in the title to prevent false matches escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") cursor = self._conn.execute( - "SELECT id, title, started_at FROM sessions " - "WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC", + "SELECT id, title, started_at FROM sessions WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC", (f"{escaped} #%",), ) numbered = cursor.fetchall() @@ -373,8 +359,9 @@ class SessionDB: the highest existing number and increments. """ import re + # Strip existing #N suffix to find the true base - match = re.match(r'^(.*?) #(\d+)$', base_title) + match = re.match(r"^(.*?) #(\d+)$", base_title) if match: base = match.group(1) else: @@ -395,7 +382,7 @@ class SessionDB: # Find the highest number max_num = 1 # The unnumbered original counts as #1 for t in existing: - m = re.match(r'^.* #(\d+)$', t) + m = re.match(r"^.* #(\d+)$", t) if m: max_num = max(max_num, int(m.group(1))) @@ -406,7 +393,7 @@ class SessionDB: source: str = None, limit: int = 20, offset: int = 0, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """List sessions with preview (first user message) and last active timestamp. Returns dicts with keys: id, source, model, title, started_at, ended_at, @@ -506,7 +493,7 @@ class SessionDB: self._conn.commit() return msg_id - def get_messages(self, session_id: str) -> List[Dict[str, Any]]: + def get_messages(self, session_id: str) -> list[dict[str, Any]]: """Load all messages for a session, ordered by timestamp.""" cursor = self._conn.execute( "SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id", @@ -524,7 +511,7 @@ class SessionDB: result.append(msg) return result - def get_messages_as_conversation(self, session_id: str) -> List[Dict[str, Any]]: + def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]: """ Load messages in the OpenAI conversation format (role + content dicts). Used by the gateway to restore conversation history. @@ -556,11 +543,11 @@ class SessionDB: def search_messages( self, query: str, - source_filter: List[str] = None, - role_filter: List[str] = None, + source_filter: list[str] = None, + role_filter: list[str] = None, limit: int = 20, offset: int = 0, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ Full-text search across session messages using FTS5. @@ -628,8 +615,7 @@ class SessionDB: (match["session_id"], match["id"], match["id"]), ) context_msgs = [ - {"role": r["role"], "content": (r["content"] or "")[:200]} - for r in ctx_cursor.fetchall() + {"role": r["role"], "content": (r["content"] or "")[:200]} for r in ctx_cursor.fetchall() ] match["context"] = context_msgs except Exception: @@ -645,7 +631,7 @@ class SessionDB: source: str = None, limit: int = 20, offset: int = 0, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """List sessions, optionally filtered by source.""" if source: cursor = self._conn.execute( @@ -666,9 +652,7 @@ class SessionDB: def session_count(self, source: str = None) -> int: """Count sessions, optionally filtered by source.""" if source: - cursor = self._conn.execute( - "SELECT COUNT(*) FROM sessions WHERE source = ?", (source,) - ) + cursor = self._conn.execute("SELECT COUNT(*) FROM sessions WHERE source = ?", (source,)) else: cursor = self._conn.execute("SELECT COUNT(*) FROM sessions") return cursor.fetchone()[0] @@ -676,9 +660,7 @@ class SessionDB: def message_count(self, session_id: str = None) -> int: """Count messages, optionally for a specific session.""" if session_id: - cursor = self._conn.execute( - "SELECT COUNT(*) FROM messages WHERE session_id = ?", (session_id,) - ) + cursor = self._conn.execute("SELECT COUNT(*) FROM messages WHERE session_id = ?", (session_id,)) else: cursor = self._conn.execute("SELECT COUNT(*) FROM messages") return cursor.fetchone()[0] @@ -687,7 +669,7 @@ class SessionDB: # Export and cleanup # ========================================================================= - def export_session(self, session_id: str) -> Optional[Dict[str, Any]]: + def export_session(self, session_id: str) -> dict[str, Any] | None: """Export a single session with all its messages as a dict.""" session = self.get_session(session_id) if not session: @@ -695,7 +677,7 @@ class SessionDB: messages = self.get_messages(session_id) return {**session, "messages": messages} - def export_all(self, source: str = None) -> List[Dict[str, Any]]: + def export_all(self, source: str = None) -> list[dict[str, Any]]: """ Export all sessions (with messages) as a list of dicts. Suitable for writing to a JSONL file for backup/analysis. @@ -709,9 +691,7 @@ class SessionDB: def clear_messages(self, session_id: str) -> None: """Delete all messages for a session and reset its counters.""" - self._conn.execute( - "DELETE FROM messages WHERE session_id = ?", (session_id,) - ) + self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,)) self._conn.execute( "UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?", (session_id,), @@ -720,9 +700,7 @@ class SessionDB: def delete_session(self, session_id: str) -> bool: """Delete a session and all its messages. Returns True if found.""" - cursor = self._conn.execute( - "SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,) - ) + cursor = self._conn.execute("SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,)) if cursor.fetchone()[0] == 0: return False self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,)) @@ -736,6 +714,7 @@ class SessionDB: Only prunes ended sessions (not active ones). """ import time as _time + cutoff = _time.time() - (older_than_days * 86400) if source: diff --git a/model_tools.py b/model_tools.py index 97a96e7a14..59c398357a 100644 --- a/model_tools.py +++ b/model_tools.py @@ -20,11 +20,10 @@ Public API (signatures preserved from the original 2,400-line version): check_tool_availability(quiet) -> tuple """ -import json import asyncio -import os +import json import logging -from typing import Dict, Any, List, Optional, Tuple +from typing import Any from tools.registry import registry from toolsets import resolve_toolset, validate_toolset @@ -36,6 +35,7 @@ logger = logging.getLogger(__name__) # Async Bridging (single source of truth -- used by registry.dispatch too) # ============================================================================= + def _run_async(coro): """Run an async coroutine from a sync context. @@ -56,6 +56,7 @@ def _run_async(coro): if loop and loop.is_running(): import concurrent.futures + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: future = pool.submit(asyncio.run, coro) return future.result(timeout=300) @@ -66,6 +67,7 @@ def _run_async(coro): # Tool Discovery (importing each module triggers its registry.register calls) # ============================================================================= + def _discover_tools(): """Import all tool modules to trigger their registry.register() calls. @@ -97,6 +99,7 @@ def _discover_tools(): "tools.homeassistant_tool", ] import importlib + for mod_name in _modules: try: importlib.import_module(mod_name) @@ -109,6 +112,7 @@ _discover_tools() # MCP tool discovery (external MCP servers from config) try: from tools.mcp_tool import discover_mcp_tools + discover_mcp_tools() except Exception as e: logger.debug("MCP tool discovery failed: %s", e) @@ -118,13 +122,13 @@ except Exception as e: # Backward-compat constants (built once after discovery) # ============================================================================= -TOOL_TO_TOOLSET_MAP: Dict[str, str] = registry.get_tool_to_toolset_map() +TOOL_TO_TOOLSET_MAP: dict[str, str] = registry.get_tool_to_toolset_map() -TOOLSET_REQUIREMENTS: Dict[str, dict] = registry.get_toolset_requirements() +TOOLSET_REQUIREMENTS: dict[str, dict] = registry.get_toolset_requirements() # Resolved tool names from the last get_tool_definitions() call. # Used by code_execution_tool to know which tools are available in this session. -_last_resolved_tool_names: List[str] = [] +_last_resolved_tool_names: list[str] = [] # ============================================================================= @@ -139,18 +143,29 @@ _LEGACY_TOOLSET_MAP = { "image_tools": ["image_generate"], "skills_tools": ["skills_list", "skill_view", "skill_manage"], "browser_tools": [ - "browser_navigate", "browser_snapshot", "browser_click", - "browser_type", "browser_scroll", "browser_back", - "browser_press", "browser_close", "browser_get_images", - "browser_vision" + "browser_navigate", + "browser_snapshot", + "browser_click", + "browser_type", + "browser_scroll", + "browser_back", + "browser_press", + "browser_close", + "browser_get_images", + "browser_vision", ], "cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"], "rl_tools": [ - "rl_list_environments", "rl_select_environment", - "rl_get_current_config", "rl_edit_config", - "rl_start_training", "rl_check_status", - "rl_stop_training", "rl_get_results", - "rl_list_runs", "rl_test_inference" + "rl_list_environments", + "rl_select_environment", + "rl_get_current_config", + "rl_edit_config", + "rl_start_training", + "rl_check_status", + "rl_stop_training", + "rl_get_results", + "rl_list_runs", + "rl_test_inference", ], "file_tools": ["read_file", "write_file", "patch", "search_files"], "tts_tools": ["text_to_speech"], @@ -161,11 +176,12 @@ _LEGACY_TOOLSET_MAP = { # get_tool_definitions (the main schema provider) # ============================================================================= + def get_tool_definitions( - enabled_toolsets: List[str] = None, - disabled_toolsets: List[str] = None, + enabled_toolsets: list[str] = None, + disabled_toolsets: list[str] = None, quiet_mode: bool = False, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ Get tool definitions for model API calls with toolset-based filtering. @@ -200,6 +216,7 @@ def get_tool_definitions( elif disabled_toolsets: from toolsets import get_all_toolsets + for ts_name in get_all_toolsets(): tools_to_include.update(resolve_toolset(ts_name)) @@ -219,6 +236,7 @@ def get_tool_definitions( print(f"⚠️ Unknown toolset: {toolset_name}") else: from toolsets import get_all_toolsets + for ts_name in get_all_toolsets(): tools_to_include.update(resolve_toolset(ts_name)) @@ -230,6 +248,7 @@ def get_tool_definitions( # execute_code" even when the user disabled the web toolset (#560-discord). if "execute_code" in tools_to_include: from tools.code_execution_tool import SANDBOX_ALLOWED_TOOLS, build_execute_code_schema + sandbox_enabled = SANDBOX_ALLOWED_TOOLS & tools_to_include dynamic_schema = build_execute_code_schema(sandbox_enabled) for i, td in enumerate(filtered_tools): @@ -263,9 +282,9 @@ _AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task"} def handle_function_call( function_name: str, - function_args: Dict[str, Any], - task_id: Optional[str] = None, - user_task: Optional[str] = None, + function_args: dict[str, Any], + task_id: str | None = None, + user_task: str | None = None, ) -> str: """ Main function call dispatcher that routes calls to the tool registry. @@ -285,13 +304,15 @@ def handle_function_call( if function_name == "execute_code": return registry.dispatch( - function_name, function_args, + function_name, + function_args, task_id=task_id, enabled_tools=_last_resolved_tool_names, ) return registry.dispatch( - function_name, function_args, + function_name, + function_args, task_id=task_id, user_task=user_task, ) @@ -306,26 +327,27 @@ def handle_function_call( # Backward-compat wrapper functions # ============================================================================= -def get_all_tool_names() -> List[str]: + +def get_all_tool_names() -> list[str]: """Return all registered tool names.""" return registry.get_all_tool_names() -def get_toolset_for_tool(tool_name: str) -> Optional[str]: +def get_toolset_for_tool(tool_name: str) -> str | None: """Return the toolset a tool belongs to.""" return registry.get_toolset_for_tool(tool_name) -def get_available_toolsets() -> Dict[str, dict]: +def get_available_toolsets() -> dict[str, dict]: """Return toolset availability info for UI display.""" return registry.get_available_toolsets() -def check_toolset_requirements() -> Dict[str, bool]: +def check_toolset_requirements() -> dict[str, bool]: """Return {toolset: available_bool} for every registered toolset.""" return registry.check_toolset_requirements() -def check_tool_availability(quiet: bool = False) -> Tuple[List[str], List[dict]]: +def check_tool_availability(quiet: bool = False) -> tuple[list[str], list[dict]]: """Return (available_toolsets, unavailable_info).""" return registry.check_tool_availability(quiet=quiet) diff --git a/pyproject.toml b/pyproject.toml index 01bdaf7e23..a225fe871d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ [project.optional-dependencies] modal = ["swe-rex[modal]>=1.4.0"] daytona = ["daytona>=0.148.0"] -dev = ["pytest", "pytest-asyncio", "mcp>=1.2.0"] +dev = ["pytest", "pytest-asyncio", "mcp>=1.2.0", "ruff", "pre-commit", "watchfiles"] messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0", "aiohttp>=3.9.0", "slack-bolt>=1.18.0", "slack-sdk>=3.27.0"] cron = ["croniter"] slack = ["slack-bolt>=1.18.0", "slack-sdk>=3.27.0"] @@ -76,6 +76,46 @@ py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajector [tool.setuptools.packages.find] include = ["tools", "hermes_cli", "gateway", "cron", "honcho_integration"] +[tool.ruff] +target-version = "py311" +line-length = 120 + +[tool.ruff.lint] +select = ["E", "F", "W", "I", "UP", "B", "SIM"] +ignore = [ + "E402", # late imports — intentional throughout codebase + "E501", # line too long — handled by formatter where it can + "E731", # lambda assignments — used in registry pattern + "E741", # ambiguous variable name — existing patterns + "F811", # redefined unused — intentional overrides + "F841", # unused variable — cleanup separately + "B007", # unused loop variable — cleanup separately + "B904", # raise from — too noisy to gate on + "B905", # zip strict — cleanup separately + "B027", # empty method without abstract decorator + "SIM102", # collapsible if — readability preference + "SIM103", # needless bool — readability preference + "SIM105", # suppressible exception — existing pattern + "SIM108", # ternary — readability preference + "SIM110", # reimplemented builtin + "SIM112", # uncapitalized env var + "SIM115", # open file with context handler + "SIM117", # multiple with statements + "SIM118", # in-dict-keys — cleanup separately + "SIM212", # if-expr twisted arms +] + +[tool.ruff.lint.per-file-ignores] +"batch_runner.py" = ["F821"] +"tools/patch_parser.py" = ["F821"] +"gateway/run.py" = ["F821"] +"gateway/channel_directory.py" = ["F401"] +"hermes_cli/doctor.py" = ["F401"] +"tools/image_generation_tool.py" = ["F401"] + +[tool.ruff.lint.isort] +known-first-party = ["tools", "hermes_cli", "gateway", "agent", "cron"] + [tool.pytest.ini_options] testpaths = ["tests"] markers = [ diff --git a/run_agent.py b/run_agent.py index c1f2623c83..9166cdb3a6 100644 --- a/run_agent.py +++ b/run_agent.py @@ -15,7 +15,7 @@ Features: Usage: from run_agent import AIAgent - + agent = AIAgent(base_url="http://localhost:30000/v1", model="claude-opus-4-20250514") response = agent.run_conversation("Tell me about the latest Python updates") """ @@ -24,27 +24,28 @@ import copy import hashlib import json import logging + logger = logging.getLogger(__name__) import os import random import re -import sys -import time import threading -from types import SimpleNamespace +import time import uuid -from typing import List, Dict, Any, Optional -from openai import OpenAI -import fire from datetime import datetime from pathlib import Path +from types import SimpleNamespace +from typing import Any + +import fire # Load .env from ~/.hermes/.env first, then project root as dev fallback from dotenv import load_dotenv +from openai import OpenAI _hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) _user_env = _hermes_home / ".env" -_project_env = Path(__file__).parent / '.env' +_project_env = Path(__file__).parent / ".env" if _user_env.exists(): try: load_dotenv(dotenv_path=_user_env, encoding="utf-8") @@ -65,38 +66,49 @@ os.environ.setdefault("MSWEA_GLOBAL_CONFIG_DIR", str(_hermes_home)) os.environ.setdefault("MSWEA_SILENT_STARTUP", "1") # Import our tool system -from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements -from tools.terminal_tool import cleanup_vm -from tools.interrupt import set_interrupt as _set_interrupt -from tools.browser_tool import cleanup_browser -import requests - -from hermes_constants import OPENROUTER_BASE_URL, OPENROUTER_MODELS_URL +from agent.context_compressor import ContextCompressor +from agent.display import ( + KawaiiSpinner, + _detect_tool_failure, +) +from agent.display import ( + build_tool_preview as _build_tool_preview, +) +from agent.display import ( + get_cute_tool_message as _get_cute_tool_message_impl, +) +from agent.model_metadata import ( + estimate_messages_tokens_rough, + estimate_tokens_rough, + get_next_probe_tier, + parse_context_limit_from_error, + save_context_length, +) # Agent internals extracted to agent/ package for modularity from agent.prompt_builder import ( - DEFAULT_AGENT_IDENTITY, PLATFORM_HINTS, - MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, SKILLS_GUIDANCE, + DEFAULT_AGENT_IDENTITY, + MEMORY_GUIDANCE, + PLATFORM_HINTS, + SESSION_SEARCH_GUIDANCE, + SKILLS_GUIDANCE, + build_context_files_prompt, + build_skills_system_prompt, ) -from agent.model_metadata import ( - fetch_model_metadata, get_model_context_length, - estimate_tokens_rough, estimate_messages_tokens_rough, - get_next_probe_tier, parse_context_limit_from_error, - save_context_length, -) -from agent.context_compressor import ContextCompressor from agent.prompt_caching import apply_anthropic_cache_control -from agent.prompt_builder import build_skills_system_prompt, build_context_files_prompt -from agent.display import ( - KawaiiSpinner, build_tool_preview as _build_tool_preview, - get_cute_tool_message as _get_cute_tool_message_impl, - _detect_tool_failure, +from agent.trajectory import ( + convert_scratchpad_to_think, + has_incomplete_scratchpad, ) from agent.trajectory import ( - convert_scratchpad_to_think, has_incomplete_scratchpad, save_trajectory as _save_trajectory_to_file, ) +from hermes_constants import OPENROUTER_BASE_URL +from model_tools import check_toolset_requirements, get_tool_definitions, handle_function_call +from tools.browser_tool import cleanup_browser +from tools.interrupt import set_interrupt as _set_interrupt +from tools.terminal_tool import cleanup_vm class IterationBudget: @@ -142,11 +154,11 @@ class IterationBudget: class AIAgent: """ AI Agent with tool calling capabilities. - + This class manages the conversation flow, tool execution, and response handling for AI models that support function calling. """ - + def __init__( self, base_url: str = None, @@ -156,17 +168,17 @@ class AIAgent: model: str = "anthropic/claude-opus-4.6", # OpenRouter format max_iterations: int = 90, # Default tool-calling iterations (shared with subagents) tool_delay: float = 1.0, - enabled_toolsets: List[str] = None, - disabled_toolsets: List[str] = None, + enabled_toolsets: list[str] = None, + disabled_toolsets: list[str] = None, save_trajectories: bool = False, verbose_logging: bool = False, quiet_mode: bool = False, ephemeral_system_prompt: str = None, log_prefix_chars: int = 100, log_prefix: str = "", - providers_allowed: List[str] = None, - providers_ignored: List[str] = None, - providers_order: List[str] = None, + providers_allowed: list[str] = None, + providers_ignored: list[str] = None, + providers_order: list[str] = None, provider_sort: str = None, provider_require_parameters: bool = False, provider_data_collection: str = None, @@ -175,15 +187,15 @@ class AIAgent: clarify_callback: callable = None, step_callback: callable = None, max_tokens: int = None, - reasoning_config: Dict[str, Any] = None, - prefill_messages: List[Dict[str, Any]] = None, + reasoning_config: dict[str, Any] = None, + prefill_messages: list[dict[str, Any]] = None, platform: str = None, skip_context_files: bool = False, skip_memory: bool = False, session_db=None, honcho_session_key: str = None, iteration_budget: "IterationBudget" = None, - fallback_model: Dict[str, Any] = None, + fallback_model: dict[str, Any] = None, ): """ Initialize the AI Agent. @@ -259,15 +271,15 @@ class AIAgent: self.clarify_callback = clarify_callback self.step_callback = step_callback self._last_reported_tool = None # Track for "new tool" mode - + # Interrupt mechanism for breaking out of tool loops self._interrupt_requested = False self._interrupt_message = None # Optional message that triggered interrupt - + # Subagent delegation state - self._delegate_depth = 0 # 0 = top-level agent, incremented for children - self._active_children = [] # Running child AIAgents (for interrupt propagation) - + self._delegate_depth = 0 # 0 = top-level agent, incremented for children + self._active_children = [] # Running child AIAgents (for interrupt propagation) + # Store OpenRouter provider preferences self.providers_allowed = providers_allowed self.providers_ignored = providers_ignored @@ -279,12 +291,12 @@ class AIAgent: # Store toolset filtering options self.enabled_toolsets = enabled_toolsets self.disabled_toolsets = disabled_toolsets - + # Model response configuration self.max_tokens = max_tokens # None = use model default self.reasoning_config = reasoning_config # None = use default (medium for OpenRouter) self.prefill_messages = prefill_messages or [] # Prefilled conversation turns - + # Anthropic prompt caching: auto-enabled for Claude models via OpenRouter. # Reduces input costs by ~75% on multi-turn conversations by caching the # conversation prefix. Uses system_and_3 strategy (4 breakpoints). @@ -292,90 +304,94 @@ class AIAgent: is_claude = "claude" in self.model.lower() self._use_prompt_caching = is_openrouter and is_claude self._cache_ttl = "5m" # Default 5-minute TTL (1.25x write cost) - + # Persistent error log -- always writes WARNING+ to ~/.hermes/logs/errors.log # so tool failures, API errors, etc. are inspectable after the fact. from agent.redact import RedactingFormatter + _error_log_dir = Path.home() / ".hermes" / "logs" _error_log_dir.mkdir(parents=True, exist_ok=True) _error_log_path = _error_log_dir / "errors.log" from logging.handlers import RotatingFileHandler + _error_file_handler = RotatingFileHandler( - _error_log_path, maxBytes=2 * 1024 * 1024, backupCount=2, + _error_log_path, + maxBytes=2 * 1024 * 1024, + backupCount=2, ) _error_file_handler.setLevel(logging.WARNING) - _error_file_handler.setFormatter(RedactingFormatter( - '%(asctime)s %(levelname)s %(name)s: %(message)s', - )) + _error_file_handler.setFormatter( + RedactingFormatter( + "%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + ) logging.getLogger().addHandler(_error_file_handler) if self.verbose_logging: logging.basicConfig( - level=logging.DEBUG, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - datefmt='%H:%M:%S' + level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%H:%M:%S" ) for handler in logging.getLogger().handlers: - handler.setFormatter(RedactingFormatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s', - datefmt='%H:%M:%S', - )) + handler.setFormatter( + RedactingFormatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%H:%M:%S", + ) + ) # Keep third-party libraries at WARNING level to reduce noise # We have our own retry and error logging that's more informative - logging.getLogger('openai').setLevel(logging.WARNING) - logging.getLogger('openai._base_client').setLevel(logging.WARNING) - logging.getLogger('httpx').setLevel(logging.WARNING) - logging.getLogger('httpcore').setLevel(logging.WARNING) - logging.getLogger('asyncio').setLevel(logging.WARNING) + logging.getLogger("openai").setLevel(logging.WARNING) + logging.getLogger("openai._base_client").setLevel(logging.WARNING) + logging.getLogger("httpx").setLevel(logging.WARNING) + logging.getLogger("httpcore").setLevel(logging.WARNING) + logging.getLogger("asyncio").setLevel(logging.WARNING) # Suppress Modal/gRPC related debug spam - logging.getLogger('hpack').setLevel(logging.WARNING) - logging.getLogger('hpack.hpack').setLevel(logging.WARNING) - logging.getLogger('grpc').setLevel(logging.WARNING) - logging.getLogger('modal').setLevel(logging.WARNING) - logging.getLogger('rex-deploy').setLevel(logging.INFO) # Keep INFO for sandbox status + logging.getLogger("hpack").setLevel(logging.WARNING) + logging.getLogger("hpack.hpack").setLevel(logging.WARNING) + logging.getLogger("grpc").setLevel(logging.WARNING) + logging.getLogger("modal").setLevel(logging.WARNING) + logging.getLogger("rex-deploy").setLevel(logging.INFO) # Keep INFO for sandbox status logger.info("Verbose logging enabled (third-party library logs suppressed)") else: # Set logging to INFO level for important messages only logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', - datefmt='%H:%M:%S' + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", datefmt="%H:%M:%S" ) # Suppress noisy library logging - logging.getLogger('openai').setLevel(logging.ERROR) - logging.getLogger('openai._base_client').setLevel(logging.ERROR) - logging.getLogger('httpx').setLevel(logging.ERROR) - logging.getLogger('httpcore').setLevel(logging.ERROR) + logging.getLogger("openai").setLevel(logging.ERROR) + logging.getLogger("openai._base_client").setLevel(logging.ERROR) + logging.getLogger("httpx").setLevel(logging.ERROR) + logging.getLogger("httpcore").setLevel(logging.ERROR) if self.quiet_mode: # In quiet mode (CLI default), suppress all tool/infra log # noise. The TUI has its own rich display for status; logger # INFO/WARNING messages just clutter it. for quiet_logger in [ - 'tools', # all tools.* (terminal, browser, web, file, etc.) - 'minisweagent', # mini-swe-agent execution backend - 'run_agent', # agent runner internals - 'trajectory_compressor', - 'cron', # scheduler (only relevant in daemon mode) - 'hermes_cli', # CLI helpers + "tools", # all tools.* (terminal, browser, web, file, etc.) + "minisweagent", # mini-swe-agent execution backend + "run_agent", # agent runner internals + "trajectory_compressor", + "cron", # scheduler (only relevant in daemon mode) + "hermes_cli", # CLI helpers ]: logging.getLogger(quiet_logger).setLevel(logging.ERROR) - + # Initialize OpenAI client - defaults to OpenRouter client_kwargs = {} - + # Default to OpenRouter if no base_url provided if base_url: client_kwargs["base_url"] = base_url else: client_kwargs["base_url"] = OPENROUTER_BASE_URL - + # Handle API key - OpenRouter is the primary provider if api_key: client_kwargs["api_key"] = api_key else: # Primary: OPENROUTER_API_KEY, fallback to direct provider keys client_kwargs["api_key"] = os.getenv("OPENROUTER_API_KEY", "") - + # OpenRouter app attribution — shows hermes-agent in rankings/analytics effective_base = client_kwargs.get("base_url", "") if "openrouter" in effective_base.lower(): @@ -390,7 +406,7 @@ class AIAgent: client_kwargs["default_headers"] = { "User-Agent": "KimiCLI/1.0", } - + self._client_kwargs = client_kwargs # stored for rebuilding after interrupt try: self.client = OpenAI(**client_kwargs) @@ -403,10 +419,12 @@ class AIAgent: if key_used and key_used != "dummy-key" and len(key_used) > 12: print(f"🔑 Using API key: {key_used[:8]}...{key_used[-4:]}") else: - print(f"⚠️ Warning: API key appears invalid or missing (got: '{key_used[:20] if key_used else 'none'}...')") + print( + f"⚠️ Warning: API key appears invalid or missing (got: '{key_used[:20] if key_used else 'none'}...')" + ) except Exception as e: raise RuntimeError(f"Failed to initialize OpenAI client: {e}") - + # Provider fallback — a single backup model/provider tried when the # primary is exhausted (rate-limit, overload, connection failure). # Config shape: {"provider": "openrouter", "model": "anthropic/claude-sonnet-4"} @@ -424,7 +442,7 @@ class AIAgent: disabled_toolsets=disabled_toolsets, quiet_mode=self.quiet_mode, ) - + # Show tool configuration and store valid tool names for validation self.valid_tool_names = set() if self.tools: @@ -432,7 +450,7 @@ class AIAgent: tool_names = sorted(self.valid_tool_names) if not self.quiet_mode: print(f"🛠️ Loaded {len(self.tools)} tools: {', '.join(tool_names)}") - + # Show filtering info if applied if enabled_toolsets: print(f" ✅ Enabled toolsets: {', '.join(enabled_toolsets)}") @@ -440,27 +458,31 @@ class AIAgent: print(f" ❌ Disabled toolsets: {', '.join(disabled_toolsets)}") elif not self.quiet_mode: print("🛠️ No tools loaded (all tools filtered out or unavailable)") - + # Check tool requirements if self.tools and not self.quiet_mode: requirements = check_toolset_requirements() missing_reqs = [name for name, available in requirements.items() if not available] if missing_reqs: print(f"⚠️ Some tools may not work due to missing requirements: {missing_reqs}") - + # Show trajectory saving status if self.save_trajectories and not self.quiet_mode: print("📝 Trajectory saving enabled") - + # Show ephemeral system prompt status if self.ephemeral_system_prompt and not self.quiet_mode: - prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt + prompt_preview = ( + self.ephemeral_system_prompt[:60] + "..." + if len(self.ephemeral_system_prompt) > 60 + else self.ephemeral_system_prompt + ) print(f"🔒 Ephemeral system prompt: '{prompt_preview}' (not saved to trajectories)") - + # Show prompt caching status if self._use_prompt_caching and not self.quiet_mode: print(f"💾 Prompt caching: ENABLED (Claude via OpenRouter, {self._cache_ttl} TTL)") - + # Session logging setup - auto-save conversation trajectories for debugging self.session_start = datetime.now() if session_id: @@ -471,19 +493,19 @@ class AIAgent: timestamp_str = self.session_start.strftime("%Y%m%d_%H%M%S") short_uuid = uuid.uuid4().hex[:6] self.session_id = f"{timestamp_str}_{short_uuid}" - + # Session logs go into ~/.hermes/sessions/ alongside gateway sessions hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) self.logs_dir = hermes_home / "sessions" self.logs_dir.mkdir(parents=True, exist_ok=True) self.session_log_file = self.logs_dir / f"session_{self.session_id}.json" - + # Track conversation messages for session logging - self._session_messages: List[Dict[str, Any]] = [] - + self._session_messages: list[dict[str, Any]] = [] + # Cached system prompt -- built once per session, only rebuilt on compression - self._cached_system_prompt: Optional[str] = None - + self._cached_system_prompt: str | None = None + # SQLite session store (optional -- provided by CLI or gateway) self._session_db = session_db if self._session_db: @@ -501,11 +523,12 @@ class AIAgent: ) except Exception as e: logger.debug("Session DB create_session failed: %s", e) - + # In-memory todo list for task planning (one per agent/session) from tools.todo_tool import TodoStore + self._todo_store = TodoStore() - + # Persistent memory (MEMORY.md + USER.md) -- loaded from disk self._memory_store = None self._memory_enabled = False @@ -515,6 +538,7 @@ class AIAgent: if not skip_memory: try: from hermes_cli.config import load_config as _load_mem_config + mem_config = _load_mem_config().get("memory", {}) self._memory_enabled = mem_config.get("memory_enabled", False) self._user_profile_enabled = mem_config.get("user_profile_enabled", False) @@ -522,6 +546,7 @@ class AIAgent: self._memory_flush_min_turns = int(mem_config.get("flush_min_turns", 6)) if self._memory_enabled or self._user_profile_enabled: from tools.memory_tool import MemoryStore + self._memory_store = MemoryStore( memory_char_limit=mem_config.get("memory_char_limit", 2200), user_char_limit=mem_config.get("user_char_limit", 1375), @@ -529,7 +554,7 @@ class AIAgent: self._memory_store.load_from_disk() except Exception: pass # Memory is optional -- don't break agent init - + # Honcho AI-native memory (cross-session user modeling) # Reads ~/.honcho/config.json as the single source of truth. self._honcho = None # HonchoSessionManager | None @@ -537,9 +562,11 @@ class AIAgent: if not skip_memory: try: from honcho_integration.client import HonchoClientConfig, get_honcho_client + hcfg = HonchoClientConfig.from_global_config() if hcfg.enabled and hcfg.api_key: from honcho_integration.session import HonchoSessionManager + client = get_honcho_client(hcfg) self._honcho = HonchoSessionManager( honcho=client, @@ -548,18 +575,18 @@ class AIAgent: ) # Resolve session key: explicit arg > global sessions map > fallback if not self._honcho_session_key: - self._honcho_session_key = ( - hcfg.resolve_session_name() - or "hermes-default" - ) + self._honcho_session_key = hcfg.resolve_session_name() or "hermes-default" # Ensure session exists in Honcho self._honcho.get_or_create(self._honcho_session_key) # Inject session context into the honcho tool module from tools.honcho_tools import set_session_context + set_session_context(self._honcho, self._honcho_session_key) logger.info( "Honcho active (session: %s, user: %s, workspace: %s)", - self._honcho_session_key, hcfg.peer_name, hcfg.workspace_id, + self._honcho_session_key, + hcfg.peer_name, + hcfg.workspace_id, ) else: if not hcfg.enabled: @@ -574,18 +601,19 @@ class AIAgent: self._skill_nudge_interval = 15 try: from hermes_cli.config import load_config as _load_skills_config + skills_config = _load_skills_config().get("skills", {}) self._skill_nudge_interval = int(skills_config.get("creation_nudge_interval", 15)) except Exception: pass - + # Initialize context compressor for automatic context management # Compresses conversation when approaching model's context limit # Configuration via config.yaml (compression section) or environment variables compression_threshold = float(os.getenv("CONTEXT_COMPRESSION_THRESHOLD", "0.85")) compression_enabled = os.getenv("CONTEXT_COMPRESSION_ENABLED", "true").lower() in ("true", "1", "yes") compression_summary_model = os.getenv("CONTEXT_COMPRESSION_MODEL") or None - + self.context_compressor = ContextCompressor( model=self.model, threshold_percent=compression_threshold, @@ -604,24 +632,25 @@ class AIAgent: self.session_completion_tokens = 0 self.session_total_tokens = 0 self.session_api_calls = 0 - + if not self.quiet_mode: if compression_enabled: - print(f"📊 Context limit: {self.context_compressor.context_length:,} tokens (compress at {int(compression_threshold*100)}% = {self.context_compressor.threshold_tokens:,})") + print( + f"📊 Context limit: {self.context_compressor.context_length:,} tokens (compress at {int(compression_threshold * 100)}% = {self.context_compressor.threshold_tokens:,})" + ) else: - print(f"📊 Context limit: {self.context_compressor.context_length:,} tokens (auto-compression disabled)") - + print( + f"📊 Context limit: {self.context_compressor.context_length:,} tokens (auto-compression disabled)" + ) + def _max_tokens_param(self, value: int) -> dict: """Return the correct max tokens kwarg for the current provider. - + OpenAI's newer models (gpt-4o, o-series, gpt-5+) require 'max_completion_tokens'. OpenRouter, local models, and older OpenAI models use 'max_tokens'. """ - _is_direct_openai = ( - "api.openai.com" in self.base_url.lower() - and "openrouter" not in self.base_url.lower() - ) + _is_direct_openai = "api.openai.com" in self.base_url.lower() and "openrouter" not in self.base_url.lower() if _is_direct_openai: return {"max_completion_tokens": value} return {"max_tokens": value} @@ -629,36 +658,36 @@ class AIAgent: def _has_content_after_think_block(self, content: str) -> bool: """ Check if content has actual text after any blocks. - + This detects cases where the model only outputs reasoning but no actual response, which indicates an incomplete generation that should be retried. - + Args: content: The assistant message content to check - + Returns: True if there's meaningful content after think blocks, False otherwise """ if not content: return False - + # Remove all ... blocks (including nested ones, non-greedy) - cleaned = re.sub(r'.*?', '', content, flags=re.DOTALL) - + cleaned = re.sub(r".*?", "", content, flags=re.DOTALL) + # Check if there's any non-whitespace content remaining return bool(cleaned.strip()) - + def _strip_think_blocks(self, content: str) -> str: """Remove ... blocks from content, returning only visible text.""" if not content: return "" - return re.sub(r'.*?', '', content, flags=re.DOTALL) + return re.sub(r".*?", "", content, flags=re.DOTALL) def _looks_like_codex_intermediate_ack( self, user_message: str, assistant_content: str, - messages: List[Dict[str, Any]], + messages: list[dict[str, Any]], ) -> bool: """Detect a planning/ack message that should continue instead of ending the turn.""" if any(isinstance(msg, dict) and msg.get("role") == "tool" for msg in messages): @@ -715,60 +744,55 @@ class AIAgent: user_text = (user_message or "").strip().lower() user_targets_workspace = ( - any(marker in user_text for marker in workspace_markers) - or "~/" in user_text - or "/" in user_text + any(marker in user_text for marker in workspace_markers) or "~/" in user_text or "/" in user_text ) assistant_mentions_action = any(marker in assistant_text for marker in action_markers) - assistant_targets_workspace = any( - marker in assistant_text for marker in workspace_markers - ) + assistant_targets_workspace = any(marker in assistant_text for marker in workspace_markers) return (user_targets_workspace or assistant_targets_workspace) and assistant_mentions_action - - - def _extract_reasoning(self, assistant_message) -> Optional[str]: + + def _extract_reasoning(self, assistant_message) -> str | None: """ Extract reasoning/thinking content from an assistant message. - + OpenRouter and various providers can return reasoning in multiple formats: 1. message.reasoning - Direct reasoning field (DeepSeek, Qwen, etc.) 2. message.reasoning_content - Alternative field (Moonshot AI, Novita, etc.) 3. message.reasoning_details - Array of {type, summary, ...} objects (OpenRouter unified) - + Args: assistant_message: The assistant message object from the API response - + Returns: Combined reasoning text, or None if no reasoning found """ reasoning_parts = [] - + # Check direct reasoning field - if hasattr(assistant_message, 'reasoning') and assistant_message.reasoning: + if hasattr(assistant_message, "reasoning") and assistant_message.reasoning: reasoning_parts.append(assistant_message.reasoning) - + # Check reasoning_content field (alternative name used by some providers) - if hasattr(assistant_message, 'reasoning_content') and assistant_message.reasoning_content: + if hasattr(assistant_message, "reasoning_content") and assistant_message.reasoning_content: # Don't duplicate if same as reasoning if assistant_message.reasoning_content not in reasoning_parts: reasoning_parts.append(assistant_message.reasoning_content) - + # Check reasoning_details array (OpenRouter unified format) # Format: [{"type": "reasoning.summary", "summary": "...", ...}, ...] - if hasattr(assistant_message, 'reasoning_details') and assistant_message.reasoning_details: + if hasattr(assistant_message, "reasoning_details") and assistant_message.reasoning_details: for detail in assistant_message.reasoning_details: if isinstance(detail, dict): # Extract summary from reasoning detail object - summary = detail.get('summary') or detail.get('content') or detail.get('text') + summary = detail.get("summary") or detail.get("content") or detail.get("text") if summary and summary not in reasoning_parts: reasoning_parts.append(summary) - + # Combine all reasoning parts if reasoning_parts: return "\n\n".join(reasoning_parts) - + return None - + def _cleanup_task_resources(self, task_id: str) -> None: """Clean up VM and browser resources for a given task.""" try: @@ -782,7 +806,7 @@ class AIAgent: if self.verbose_logging: logging.warning(f"Failed to cleanup browser for task {task_id}: {e}") - def _persist_session(self, messages: List[Dict], conversation_history: List[Dict] = None): + def _persist_session(self, messages: list[dict], conversation_history: list[dict] = None): """Save session state to both JSON log and SQLite on any exit path. Ensures conversations are never lost, even on errors or early returns. @@ -791,7 +815,7 @@ class AIAgent: self._save_session_log(messages) self._flush_messages_to_session_db(messages, conversation_history) - def _log_msg_to_db(self, msg: Dict): + def _log_msg_to_db(self, msg: dict): """Log a single message to SQLite immediately. Called after each messages.append().""" if not self._session_db: return @@ -801,8 +825,7 @@ class AIAgent: tool_calls_data = None if hasattr(msg, "tool_calls") and msg.tool_calls: tool_calls_data = [ - {"name": tc.function.name, "arguments": tc.function.arguments} - for tc in msg.tool_calls + {"name": tc.function.name, "arguments": tc.function.arguments} for tc in msg.tool_calls ] elif isinstance(msg.get("tool_calls"), list): tool_calls_data = msg["tool_calls"] @@ -818,7 +841,7 @@ class AIAgent: except Exception as e: logger.debug("Session DB log_msg failed: %s", e) - def _flush_messages_to_session_db(self, messages: List[Dict], conversation_history: List[Dict] = None): + def _flush_messages_to_session_db(self, messages: list[dict], conversation_history: list[dict] = None): """Persist any un-logged messages to the SQLite session store. Called both at the normal end of run_conversation and from every early- @@ -835,8 +858,7 @@ class AIAgent: tool_calls_data = None if hasattr(msg, "tool_calls") and msg.tool_calls: tool_calls_data = [ - {"name": tc.function.name, "arguments": tc.function.arguments} - for tc in msg.tool_calls + {"name": tc.function.name, "arguments": tc.function.arguments} for tc in msg.tool_calls ] elif isinstance(msg.get("tool_calls"), list): tool_calls_data = msg["tool_calls"] @@ -852,47 +874,47 @@ class AIAgent: except Exception as e: logger.debug("Session DB append_message failed: %s", e) - def _get_messages_up_to_last_assistant(self, messages: List[Dict]) -> List[Dict]: + def _get_messages_up_to_last_assistant(self, messages: list[dict]) -> list[dict]: """ Get messages up to (but not including) the last assistant turn. - + This is used when we need to "roll back" to the last successful point in the conversation, typically when the final assistant message is incomplete or malformed. - + Args: messages: Full message list - + Returns: Messages up to the last complete assistant turn (ending with user/tool message) """ if not messages: return [] - + # Find the index of the last assistant message last_assistant_idx = None for i in range(len(messages) - 1, -1, -1): if messages[i].get("role") == "assistant": last_assistant_idx = i break - + if last_assistant_idx is None: # No assistant message found, return all messages return messages.copy() - + # Return everything up to (not including) the last assistant message return messages[:last_assistant_idx] - + def _format_tools_for_system_message(self) -> str: """ Format tool definitions for the system message in the trajectory format. - + Returns: str: JSON string representation of tool definitions """ if not self.tools: return "[]" - + # Convert tool definitions to the format expected in trajectories formatted_tools = [] for tool in self.tools: @@ -901,26 +923,28 @@ class AIAgent: "name": func["name"], "description": func.get("description", ""), "parameters": func.get("parameters", {}), - "required": None # Match the format in the example + "required": None, # Match the format in the example } formatted_tools.append(formatted_tool) - + return json.dumps(formatted_tools, ensure_ascii=False) - - def _convert_to_trajectory_format(self, messages: List[Dict[str, Any]], user_query: str, completed: bool) -> List[Dict[str, Any]]: + + def _convert_to_trajectory_format( + self, messages: list[dict[str, Any]], user_query: str, completed: bool + ) -> list[dict[str, Any]]: """ Convert internal message format to trajectory format for saving. - + Args: messages (List[Dict]): Internal message history user_query (str): Original user query completed (bool): Whether the conversation completed successfully - + Returns: List[Dict]: Messages in trajectory format """ trajectory = [] - + # Add system message with tool definitions system_msg = ( "You are a function calling AI model. You are provided with function signatures within XML tags. " @@ -935,78 +959,72 @@ class AIAgent: "Each function call should be enclosed within XML tags.\n" "Example:\n\n{'name': ,'arguments': }\n" ) - - trajectory.append({ - "from": "system", - "value": system_msg - }) - + + trajectory.append({"from": "system", "value": system_msg}) + # Add the actual user prompt (from the dataset) as the first human message - trajectory.append({ - "from": "human", - "value": user_query - }) - + trajectory.append({"from": "human", "value": user_query}) + # Skip the first message (the user query) since we already added it above. # Prefill messages are injected at API-call time only (not in the messages # list), so no offset adjustment is needed here. i = 1 - + while i < len(messages): msg = messages[i] - + if msg["role"] == "assistant": # Check if this message has tool calls if "tool_calls" in msg and msg["tool_calls"]: # Format assistant message with tool calls # Add tags around reasoning for trajectory storage content = "" - + # Prepend reasoning in tags if available (native thinking tokens) if msg.get("reasoning") and msg["reasoning"].strip(): content = f"\n{msg['reasoning']}\n\n" - + if msg.get("content") and msg["content"].strip(): # Convert any tags to tags # (used when native thinking is disabled and model reasons via XML) content += convert_scratchpad_to_think(msg["content"]) + "\n" - + # Add tool calls wrapped in XML tags for tool_call in msg["tool_calls"]: # Parse arguments - should always succeed since we validate during conversation # but keep try-except as safety net try: - arguments = json.loads(tool_call["function"]["arguments"]) if isinstance(tool_call["function"]["arguments"], str) else tool_call["function"]["arguments"] + arguments = ( + json.loads(tool_call["function"]["arguments"]) + if isinstance(tool_call["function"]["arguments"], str) + else tool_call["function"]["arguments"] + ) except json.JSONDecodeError: # This shouldn't happen since we validate and retry during conversation, # but if it does, log warning and use empty dict - logging.warning(f"Unexpected invalid JSON in trajectory conversion: {tool_call['function']['arguments'][:100]}") + logging.warning( + f"Unexpected invalid JSON in trajectory conversion: {tool_call['function']['arguments'][:100]}" + ) arguments = {} - - tool_call_json = { - "name": tool_call["function"]["name"], - "arguments": arguments - } + + tool_call_json = {"name": tool_call["function"]["name"], "arguments": arguments} content += f"\n{json.dumps(tool_call_json, ensure_ascii=False)}\n\n" - + # Ensure every gpt turn has a block (empty if no reasoning) # so the format is consistent for training data if "" not in content: content = "\n\n" + content - - trajectory.append({ - "from": "gpt", - "value": content.rstrip() - }) - + + trajectory.append({"from": "gpt", "value": content.rstrip()}) + # Collect all subsequent tool responses tool_responses = [] j = i + 1 while j < len(messages) and messages[j]["role"] == "tool": tool_msg = messages[j] # Format tool response with XML tags - tool_response = f"\n" - + tool_response = "\n" + # Try to parse tool content as JSON if it looks like JSON tool_content = tool_msg["content"] try: @@ -1014,61 +1032,57 @@ class AIAgent: tool_content = json.loads(tool_content) except (json.JSONDecodeError, AttributeError): pass # Keep as string if not valid JSON - - tool_response += json.dumps({ - "tool_call_id": tool_msg.get("tool_call_id", ""), - "name": msg["tool_calls"][len(tool_responses)]["function"]["name"] if len(tool_responses) < len(msg["tool_calls"]) else "unknown", - "content": tool_content - }, ensure_ascii=False) + + tool_response += json.dumps( + { + "tool_call_id": tool_msg.get("tool_call_id", ""), + "name": msg["tool_calls"][len(tool_responses)]["function"]["name"] + if len(tool_responses) < len(msg["tool_calls"]) + else "unknown", + "content": tool_content, + }, + ensure_ascii=False, + ) tool_response += "\n" tool_responses.append(tool_response) j += 1 - + # Add all tool responses as a single message if tool_responses: - trajectory.append({ - "from": "tool", - "value": "\n".join(tool_responses) - }) + trajectory.append({"from": "tool", "value": "\n".join(tool_responses)}) i = j - 1 # Skip the tool messages we just processed - + else: # Regular assistant message without tool calls # Add tags around reasoning for trajectory storage content = "" - + # Prepend reasoning in tags if available (native thinking tokens) if msg.get("reasoning") and msg["reasoning"].strip(): content = f"\n{msg['reasoning']}\n\n" - + # Convert any tags to tags # (used when native thinking is disabled and model reasons via XML) raw_content = msg["content"] or "" content += convert_scratchpad_to_think(raw_content) - + # Ensure every gpt turn has a block (empty if no reasoning) if "" not in content: content = "\n\n" + content - - trajectory.append({ - "from": "gpt", - "value": content.strip() - }) - + + trajectory.append({"from": "gpt", "value": content.strip()}) + elif msg["role"] == "user": - trajectory.append({ - "from": "human", - "value": msg["content"] - }) - + trajectory.append({"from": "human", "value": msg["content"]}) + i += 1 - + return trajectory - - def _save_trajectory(self, messages: List[Dict[str, Any]], user_query: str, completed: bool): + + def _save_trajectory(self, messages: list[dict[str, Any]], user_query: str, completed: bool): """ Save conversation trajectory to JSONL file. - + Args: messages (List[Dict]): Complete message history user_query (str): Original user query @@ -1076,11 +1090,11 @@ class AIAgent: """ if not self.save_trajectories: return - + trajectory = self._convert_to_trajectory_format(messages, user_query, completed) _save_trajectory_to_file(trajectory, self.model, completed) - - def _mask_api_key_for_logs(self, key: Optional[str]) -> Optional[str]: + + def _mask_api_key_for_logs(self, key: str | None) -> str | None: if not key: return None if len(key) <= 12: @@ -1089,11 +1103,11 @@ class AIAgent: def _dump_api_request_debug( self, - api_kwargs: Dict[str, Any], + api_kwargs: dict[str, Any], *, reason: str, - error: Optional[Exception] = None, - ) -> Optional[Path]: + error: Exception | None = None, + ) -> Path | None: """ Dump a debug-friendly HTTP request record for chat.completions.create(). @@ -1112,7 +1126,7 @@ class AIAgent: except Exception as e: logger.debug("Could not extract API key for debug dump: %s", e) - dump_payload: Dict[str, Any] = { + dump_payload: dict[str, Any] = { "timestamp": datetime.now().isoformat(), "session_id": self.session_id, "reason": reason, @@ -1128,7 +1142,7 @@ class AIAgent: } if error is not None: - error_info: Dict[str, Any] = { + error_info: dict[str, Any] = { "type": type(error).__name__, "message": str(error), } @@ -1175,11 +1189,11 @@ class AIAgent: if not content: return content content = convert_scratchpad_to_think(content) - content = re.sub(r'\n+()', r'\n\1', content) - content = re.sub(r'()\n+', r'\1\n', content) + content = re.sub(r"\n+()", r"\n\1", content) + content = re.sub(r"()\n+", r"\1\n", content) return content.strip() - def _save_session_log(self, messages: List[Dict[str, Any]] = None): + def _save_session_log(self, messages: list[dict[str, Any]] = None): """ Save the full raw session to a JSON file. @@ -1223,26 +1237,26 @@ class AIAgent: except Exception as e: if self.verbose_logging: logging.warning(f"Failed to save session log: {e}") - + def interrupt(self, message: str = None) -> None: """ Request the agent to interrupt its current tool-calling loop. - + Call this from another thread (e.g., input handler, message receiver) to gracefully stop the agent and process a new message. - + Also signals long-running tool executions (e.g. terminal commands) to terminate early, so the agent can respond immediately. - + Args: message: Optional new message that triggered the interrupt. If provided, the agent will include this in its response context. - + Example (CLI): # In a separate input thread: if user_typed_something: agent.interrupt(user_input) - + Example (Messaging): # When new message arrives for active session: if session_has_running_agent: @@ -1259,18 +1273,21 @@ class AIAgent: except Exception as e: logger.debug("Failed to propagate interrupt to child agent: %s", e) if not self.quiet_mode: - print(f"\n⚡ Interrupt requested" + (f": '{message[:40]}...'" if message and len(message) > 40 else f": '{message}'" if message else "")) - + print( + "\n⚡ Interrupt requested" + + (f": '{message[:40]}...'" if message and len(message) > 40 else f": '{message}'" if message else "") + ) + def clear_interrupt(self) -> None: """Clear any pending interrupt request and the global tool interrupt signal.""" self._interrupt_requested = False self._interrupt_message = None _set_interrupt(False) - - def _hydrate_todo_store(self, history: List[Dict[str, Any]]) -> None: + + def _hydrate_todo_store(self, history: list[dict[str, Any]]) -> None: """ Recover todo state from conversation history. - + The gateway creates a fresh AIAgent per message, so the in-memory TodoStore is empty. We scan the history for the most recent todo tool response and replay it to reconstruct the state. @@ -1291,14 +1308,14 @@ class AIAgent: break except (json.JSONDecodeError, TypeError): continue - + if last_todo_response: # Replay the items into the store (replace mode) self._todo_store.write(last_todo_response, merge=False) if not self.quiet_mode: print(f"{self.log_prefix}📋 Restored {len(last_todo_response)} todo item(s) from history") _set_interrupt(False) - + @property def is_interrupted(self) -> bool: """Check if an interrupt has been requested.""" @@ -1343,11 +1360,13 @@ class AIAgent: session = self._honcho.get_or_create(self._honcho_session_key) session.add_message("user", f"[observation] {content.strip()}") self._honcho.save(session) - return json.dumps({ - "success": True, - "target": "user", - "message": "Saved to Honcho user model.", - }) + return json.dumps( + { + "success": True, + "target": "user", + "message": "Saved to Honcho user model.", + } + ) except Exception as e: logger.debug("Honcho user observation failed: %s", e) return json.dumps({"success": False, "error": f"Honcho save failed: {e}"}) @@ -1367,7 +1386,7 @@ class AIAgent: def _build_system_prompt(self, system_message: str = None) -> str: """ Assemble the full system prompt from all layers. - + Called once per session (cached on self._cached_system_prompt) and only rebuilt after context compression events. This ensures the system prompt is stable across all turns in a session, maximizing prefix cache hits. @@ -1409,7 +1428,7 @@ class AIAgent: if user_block: prompt_parts.append(user_block) - has_skills_tools = any(name in self.valid_tool_names for name in ['skills_list', 'skill_view', 'skill_manage']) + has_skills_tools = any(name in self.valid_tool_names for name in ["skills_list", "skill_view", "skill_manage"]) skills_prompt = build_skills_system_prompt() if has_skills_tools else "" if skills_prompt: prompt_parts.append(skills_prompt) @@ -1420,21 +1439,20 @@ class AIAgent: prompt_parts.append(context_files_prompt) from hermes_time import now as _hermes_now + now = _hermes_now() - prompt_parts.append( - f"Conversation started: {now.strftime('%A, %B %d, %Y %I:%M %p')}" - ) + prompt_parts.append(f"Conversation started: {now.strftime('%A, %B %d, %Y %I:%M %p')}") platform_key = (self.platform or "").lower().strip() if platform_key in PLATFORM_HINTS: prompt_parts.append(PLATFORM_HINTS[platform_key]) return "\n\n".join(prompt_parts) - + def _invalidate_system_prompt(self): """ Invalidate the cached system prompt, forcing a rebuild on the next turn. - + Called after context compression events. Also reloads memory from disk so the rebuilt prompt captures any writes from this session. """ @@ -1442,29 +1460,31 @@ class AIAgent: if self._memory_store: self._memory_store.load_from_disk() - def _responses_tools(self, tools: Optional[List[Dict[str, Any]]] = None) -> Optional[List[Dict[str, Any]]]: + def _responses_tools(self, tools: list[dict[str, Any]] | None = None) -> list[dict[str, Any]] | None: """Convert chat-completions tool schemas to Responses function-tool schemas.""" source_tools = tools if tools is not None else self.tools if not source_tools: return None - converted: List[Dict[str, Any]] = [] + converted: list[dict[str, Any]] = [] for item in source_tools: fn = item.get("function", {}) if isinstance(item, dict) else {} name = fn.get("name") if not isinstance(name, str) or not name.strip(): continue - converted.append({ - "type": "function", - "name": name, - "description": fn.get("description", ""), - "strict": False, - "parameters": fn.get("parameters", {"type": "object", "properties": {}}), - }) + converted.append( + { + "type": "function", + "name": name, + "description": fn.get("description", ""), + "strict": False, + "parameters": fn.get("parameters", {"type": "object", "properties": {}}), + } + ) return converted or None @staticmethod - def _split_responses_tool_id(raw_id: Any) -> tuple[Optional[str], Optional[str]]: + def _split_responses_tool_id(raw_id: Any) -> tuple[str | None, str | None]: """Split a stored tool id into (call_id, response_item_id).""" if not isinstance(raw_id, str): return None, None @@ -1483,7 +1503,7 @@ class AIAgent: def _derive_responses_function_call_id( self, call_id: str, - response_item_id: Optional[str] = None, + response_item_id: str | None = None, ) -> str: """Build a valid Responses `function_call.id` (must start with `fc_`).""" if isinstance(response_item_id, str): @@ -1495,13 +1515,13 @@ class AIAgent: if source.startswith("fc_"): return source if source.startswith("call_") and len(source) > len("call_"): - return f"fc_{source[len('call_'):]}" + return f"fc_{source[len('call_') :]}" sanitized = re.sub(r"[^A-Za-z0-9_-]", "", source) if sanitized.startswith("fc_"): return sanitized if sanitized.startswith("call_") and len(sanitized) > len("call_"): - return f"fc_{sanitized[len('call_'):]}" + return f"fc_{sanitized[len('call_') :]}" if sanitized: return f"fc_{sanitized[:48]}" @@ -1509,9 +1529,9 @@ class AIAgent: digest = hashlib.sha1(seed.encode("utf-8")).hexdigest()[:24] return f"fc_{digest}" - def _chat_messages_to_responses_input(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _chat_messages_to_responses_input(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: """Convert internal chat-style messages to Responses input items.""" - items: List[Dict[str, Any]] = [] + items: list[dict[str, Any]] = [] for msg in messages: if not isinstance(msg, dict): @@ -1546,9 +1566,7 @@ class AIAgent: if not isinstance(fn_name, str) or not fn_name.strip(): continue - embedded_call_id, embedded_response_item_id = self._split_responses_tool_id( - tc.get("id") - ) + embedded_call_id, embedded_response_item_id = self._split_responses_tool_id(tc.get("id")) call_id = tc.get("call_id") if not isinstance(call_id, str) or not call_id.strip(): call_id = embedded_call_id @@ -1558,7 +1576,7 @@ class AIAgent: and embedded_response_item_id.startswith("fc_") and len(embedded_response_item_id) > len("fc_") ): - call_id = f"call_{embedded_response_item_id[len('fc_'):]}" + call_id = f"call_{embedded_response_item_id[len('fc_') :]}" else: call_id = f"call_{uuid.uuid4().hex[:12]}" call_id = call_id.strip() @@ -1570,12 +1588,14 @@ class AIAgent: arguments = str(arguments) arguments = arguments.strip() or "{}" - items.append({ - "type": "function_call", - "call_id": call_id, - "name": fn_name, - "arguments": arguments, - }) + items.append( + { + "type": "function_call", + "call_id": call_id, + "name": fn_name, + "arguments": arguments, + } + ) continue items.append({"role": role, "content": content_text}) @@ -1589,19 +1609,21 @@ class AIAgent: call_id = raw_tool_call_id.strip() if not isinstance(call_id, str) or not call_id.strip(): continue - items.append({ - "type": "function_call_output", - "call_id": call_id, - "output": str(msg.get("content", "") or ""), - }) + items.append( + { + "type": "function_call_output", + "call_id": call_id, + "output": str(msg.get("content", "") or ""), + } + ) return items - def _preflight_codex_input_items(self, raw_items: Any) -> List[Dict[str, Any]]: + def _preflight_codex_input_items(self, raw_items: Any) -> list[dict[str, Any]]: if not isinstance(raw_items, list): raise ValueError("Codex Responses input must be a list of input items.") - normalized: List[Dict[str, Any]] = [] + normalized: list[dict[str, Any]] = [] for idx, item in enumerate(raw_items): if not isinstance(item, dict): raise ValueError(f"Codex Responses input[{idx}] must be an object.") @@ -1688,7 +1710,7 @@ class AIAgent: api_kwargs: Any, *, allow_stream: bool = False, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: if not isinstance(api_kwargs, dict): raise ValueError("Codex Responses request must be a dict.") @@ -1755,10 +1777,17 @@ class AIAgent: raise ValueError("Codex Responses contract requires 'store' to be false.") allowed_keys = { - "model", "instructions", "input", "tools", "store", - "reasoning", "include", "max_output_tokens", "temperature", + "model", + "instructions", + "input", + "tools", + "store", + "reasoning", + "include", + "max_output_tokens", + "temperature", } - normalized: Dict[str, Any] = { + normalized: dict[str, Any] = { "model": model, "instructions": instructions, "input": normalized_input, @@ -1794,9 +1823,7 @@ class AIAgent: unexpected = sorted(key for key in api_kwargs.keys() if key not in allowed_keys) if unexpected: - raise ValueError( - f"Codex Responses request has unsupported field(s): {', '.join(unexpected)}." - ) + raise ValueError(f"Codex Responses request has unsupported field(s): {', '.join(unexpected)}.") return normalized @@ -1806,7 +1833,7 @@ class AIAgent: if not isinstance(content, list): return "" - chunks: List[str] = [] + chunks: list[str] = [] for part in content: ptype = getattr(part, "type", None) if ptype not in {"output_text", "text"}: @@ -1820,7 +1847,7 @@ class AIAgent: """Extract a compact reasoning text from a Responses reasoning item.""" summary = getattr(item, "summary", None) if isinstance(summary, list): - chunks: List[str] = [] + chunks: list[str] = [] for part in summary: text = getattr(part, "text", None) if isinstance(text, str) and text: @@ -1852,10 +1879,10 @@ class AIAgent: error_msg = str(error_obj) if error_obj else f"Responses API returned status '{response_status}'" raise RuntimeError(error_msg) - content_parts: List[str] = [] - reasoning_parts: List[str] = [] - reasoning_items_raw: List[Dict[str, Any]] = [] - tool_calls: List[Any] = [] + content_parts: list[str] = [] + reasoning_parts: list[str] = [] + reasoning_items_raw: list[dict[str, Any]] = [] + tool_calls: list[Any] = [] has_incomplete_items = response_status in {"queued", "in_progress", "incomplete"} saw_commentary_phase = False saw_final_answer_phase = False @@ -1921,13 +1948,15 @@ class AIAgent: call_id = call_id.strip() response_item_id = raw_item_id if isinstance(raw_item_id, str) else None response_item_id = self._derive_responses_function_call_id(call_id, response_item_id) - tool_calls.append(SimpleNamespace( - id=call_id, - call_id=call_id, - response_item_id=response_item_id, - type="function", - function=SimpleNamespace(name=fn_name, arguments=arguments), - )) + tool_calls.append( + SimpleNamespace( + id=call_id, + call_id=call_id, + response_item_id=response_item_id, + type="function", + function=SimpleNamespace(name=fn_name, arguments=arguments), + ) + ) elif item_type == "custom_tool_call": fn_name = getattr(item, "name", "") or "" arguments = getattr(item, "input", "{}") @@ -1942,13 +1971,15 @@ class AIAgent: call_id = call_id.strip() response_item_id = raw_item_id if isinstance(raw_item_id, str) else None response_item_id = self._derive_responses_function_call_id(call_id, response_item_id) - tool_calls.append(SimpleNamespace( - id=call_id, - call_id=call_id, - response_item_id=response_item_id, - type="function", - function=SimpleNamespace(name=fn_name, arguments=arguments), - )) + tool_calls.append( + SimpleNamespace( + id=call_id, + call_id=call_id, + response_item_id=response_item_id, + type="function", + function=SimpleNamespace(name=fn_name, arguments=arguments), + ) + ) final_text = "\n".join([p for p in content_parts if p]).strip() if not final_text and hasattr(response, "output_text"): @@ -2122,7 +2153,7 @@ class AIAgent: """ Run the API call in a background thread so the main conversation loop can detect interrupts without waiting for the full HTTP round-trip. - + On interrupt, closes the HTTP client to cancel the in-flight request (stops token generation and avoids wasting money), then rebuilds the client for future calls. @@ -2176,9 +2207,7 @@ class AIAgent: "nous": ("resolve_nous_runtime_credentials", "chat_completions"), } - def _resolve_fallback_credentials( - self, fb_provider: str, fb_config: dict - ) -> Optional[tuple]: + def _resolve_fallback_credentials(self, fb_provider: str, fb_config: dict) -> tuple | None: """Resolve credentials for a fallback provider. Returns (api_key, base_url, api_mode) on success, or None on failure. @@ -2192,13 +2221,15 @@ class AIAgent: resolver_name, api_mode = self._FALLBACK_OAUTH_PROVIDERS[fb_provider] try: import hermes_cli.auth as _auth + resolver = getattr(_auth, resolver_name) creds = resolver() return creds["api_key"], creds["base_url"], api_mode except Exception as e: logging.warning( "Fallback to %s failed (credential resolution): %s", - fb_provider, e, + fb_provider, + e, ) return None @@ -2273,18 +2304,14 @@ class AIAgent: self._fallback_activated = True # Re-evaluate prompt caching for the new provider/model - self._use_prompt_caching = ( - "openrouter" in fb_base_url.lower() - and "claude" in fb_model.lower() - ) + self._use_prompt_caching = "openrouter" in fb_base_url.lower() and "claude" in fb_model.lower() - print( - f"{self.log_prefix}🔄 Primary model failed — switching to fallback: " - f"{fb_model} via {fb_provider}" - ) + print(f"{self.log_prefix}🔄 Primary model failed — switching to fallback: {fb_model} via {fb_provider}") logging.info( "Fallback activated: %s → %s (%s)", - old_model, fb_model, fb_provider, + old_model, + fb_model, + fb_provider, ) return True except Exception as e: @@ -2369,10 +2396,7 @@ class AIAgent: if self.reasoning_config is not None: extra_body["reasoning"] = self.reasoning_config else: - extra_body["reasoning"] = { - "enabled": True, - "effort": "medium" - } + extra_body["reasoning"] = {"enabled": True, "effort": "medium"} # Nous Portal product attribution if _is_nous: @@ -2402,7 +2426,7 @@ class AIAgent: "finish_reason": finish_reason, } - if hasattr(assistant_message, 'reasoning_details') and assistant_message.reasoning_details: + if hasattr(assistant_message, "reasoning_details") and assistant_message.reasoning_details: # Pass reasoning_details back unmodified so providers (OpenRouter, # Anthropic, OpenAI) can maintain reasoning continuity across turns. # Each provider may include opaque fields (signature, encrypted_content) @@ -2455,10 +2479,7 @@ class AIAgent: "call_id": call_id, "response_item_id": response_item_id, "type": tool_call.type, - "function": { - "name": tool_call.function.name, - "arguments": tool_call.function.arguments - }, + "function": {"name": tool_call.function.name, "arguments": tool_call.function.arguments}, } # Preserve extra_content (e.g. Gemini thought_signature) so it # is sent back on subsequent API calls. Without this, Gemini 3 @@ -2496,13 +2517,12 @@ class AIAgent: return if messages is None: - messages = getattr(self, '_session_messages', None) + messages = getattr(self, "_session_messages", None) if not messages or len(messages) < 3: return flush_content = ( - "[System: The session is being compressed. " - "Please save anything worth remembering to your memories.]" + "[System: The session is being compressed. Please save anything worth remembering to your memories.]" ) _sentinel = f"__flush_{id(self)}_{time.monotonic()}" flush_msg = {"role": "user", "content": flush_content, "_flush_sentinel": _sentinel} @@ -2527,7 +2547,7 @@ class AIAgent: # Make one API call with only the memory tool available memory_tool_def = None - for t in (self.tools or []): + for t in self.tools or []: if t.get("function", {}).get("name") == "memory": memory_tool_def = t break @@ -2539,6 +2559,7 @@ class AIAgent: # Use auxiliary client for the flush call when available -- # it's cheaper and avoids Codex Responses API incompatibility. from agent.auxiliary_client import get_text_auxiliary_client + aux_client, aux_model = get_text_auxiliary_client() if aux_client: @@ -2585,6 +2606,7 @@ class AIAgent: args = json.loads(tc.function.arguments) flush_target = args.get("target", "memory") from tools.memory_tool import memory_tool as _memory_tool + result = _memory_tool( action=args.get("action"), target=flush_target, @@ -2662,7 +2684,7 @@ class AIAgent: # If the user sent "stop" during a previous tool's execution, # do NOT start any more tools -- skip them all immediately. if self._interrupt_requested: - remaining_calls = assistant_message.tool_calls[i-1:] + remaining_calls = assistant_message.tool_calls[i - 1 :] if remaining_calls: print(f"{self.log_prefix}⚡ Interrupt: skipping {len(remaining_calls)} tool call(s)") for skipped_tc in remaining_calls: @@ -2692,7 +2714,9 @@ class AIAgent: if not self.quiet_mode: args_str = json.dumps(function_args, ensure_ascii=False) - args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str + args_preview = ( + args_str[: self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str + ) print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}") if self.tool_progress_callback: @@ -2706,6 +2730,7 @@ class AIAgent: if function_name == "todo": from tools.todo_tool import todo_tool as _todo_tool + function_result = _todo_tool( todos=function_args.get("todos"), merge=function_args.get("merge", False), @@ -2713,12 +2738,15 @@ class AIAgent: ) tool_duration = time.time() - tool_start_time if self.quiet_mode: - print(f" {_get_cute_tool_message_impl('todo', function_args, tool_duration, result=function_result)}") + print( + f" {_get_cute_tool_message_impl('todo', function_args, tool_duration, result=function_result)}" + ) elif function_name == "session_search": if not self._session_db: function_result = json.dumps({"success": False, "error": "Session database not available."}) else: from tools.session_search_tool import session_search as _session_search + function_result = _session_search( query=function_args.get("query", ""), role_filter=function_args.get("role_filter"), @@ -2728,10 +2756,13 @@ class AIAgent: ) tool_duration = time.time() - tool_start_time if self.quiet_mode: - print(f" {_get_cute_tool_message_impl('session_search', function_args, tool_duration, result=function_result)}") + print( + f" {_get_cute_tool_message_impl('session_search', function_args, tool_duration, result=function_result)}" + ) elif function_name == "memory": target = function_args.get("target", "memory") from tools.memory_tool import memory_tool as _memory_tool + function_result = _memory_tool( action=function_args.get("action"), target=target, @@ -2744,9 +2775,12 @@ class AIAgent: self._honcho_save_user_observation(function_args.get("content", "")) tool_duration = time.time() - tool_start_time if self.quiet_mode: - print(f" {_get_cute_tool_message_impl('memory', function_args, tool_duration, result=function_result)}") + print( + f" {_get_cute_tool_message_impl('memory', function_args, tool_duration, result=function_result)}" + ) elif function_name == "clarify": from tools.clarify_tool import clarify_tool as _clarify_tool + function_result = _clarify_tool( question=function_args.get("question", ""), choices=function_args.get("choices"), @@ -2754,9 +2788,12 @@ class AIAgent: ) tool_duration = time.time() - tool_start_time if self.quiet_mode: - print(f" {_get_cute_tool_message_impl('clarify', function_args, tool_duration, result=function_result)}") + print( + f" {_get_cute_tool_message_impl('clarify', function_args, tool_duration, result=function_result)}" + ) elif function_name == "delegate_task": from tools.delegate_tool import delegate_task as _delegate_task + tasks_arg = function_args.get("tasks") if tasks_arg and isinstance(tasks_arg, list): spinner_label = f"🔀 delegating {len(tasks_arg)} tasks" @@ -2766,7 +2803,7 @@ class AIAgent: spinner = None if self.quiet_mode: face = random.choice(KawaiiSpinner.KAWAII_WAITING) - spinner = KawaiiSpinner(f"{face} {spinner_label}", spinner_type='dots') + spinner = KawaiiSpinner(f"{face} {spinner_label}", spinner_type="dots") spinner.start() self._delegate_spinner = spinner _delegate_result = None @@ -2783,7 +2820,9 @@ class AIAgent: finally: self._delegate_spinner = None tool_duration = time.time() - tool_start_time - cute_msg = _get_cute_tool_message_impl('delegate_task', function_args, tool_duration, result=_delegate_result) + cute_msg = _get_cute_tool_message_impl( + "delegate_task", function_args, tool_duration, result=_delegate_result + ) if spinner: spinner.stop(cute_msg) elif self.quiet_mode: @@ -2791,26 +2830,47 @@ class AIAgent: elif self.quiet_mode: face = random.choice(KawaiiSpinner.KAWAII_WAITING) tool_emoji_map = { - 'web_search': '🔍', 'web_extract': '📄', 'web_crawl': '🕸️', - 'terminal': '💻', 'process': '⚙️', - 'read_file': '📖', 'write_file': '✍️', 'patch': '🔧', 'search_files': '🔎', - 'browser_navigate': '🌐', 'browser_snapshot': '📸', - 'browser_click': '👆', 'browser_type': '⌨️', - 'browser_scroll': '📜', 'browser_back': '◀️', - 'browser_press': '⌨️', 'browser_close': '🚪', - 'browser_get_images': '🖼️', 'browser_vision': '👁️', - 'image_generate': '🎨', 'text_to_speech': '🔊', - 'vision_analyze': '👁️', 'mixture_of_agents': '🧠', - 'skills_list': '📚', 'skill_view': '📚', - 'schedule_cronjob': '⏰', 'list_cronjobs': '⏰', 'remove_cronjob': '⏰', - 'send_message': '📨', 'todo': '📋', 'memory': '🧠', 'session_search': '🔍', - 'clarify': '❓', 'execute_code': '🐍', 'delegate_task': '🔀', + "web_search": "🔍", + "web_extract": "📄", + "web_crawl": "🕸️", + "terminal": "💻", + "process": "⚙️", + "read_file": "📖", + "write_file": "✍️", + "patch": "🔧", + "search_files": "🔎", + "browser_navigate": "🌐", + "browser_snapshot": "📸", + "browser_click": "👆", + "browser_type": "⌨️", + "browser_scroll": "📜", + "browser_back": "◀️", + "browser_press": "⌨️", + "browser_close": "🚪", + "browser_get_images": "🖼️", + "browser_vision": "👁️", + "image_generate": "🎨", + "text_to_speech": "🔊", + "vision_analyze": "👁️", + "mixture_of_agents": "🧠", + "skills_list": "📚", + "skill_view": "📚", + "schedule_cronjob": "⏰", + "list_cronjobs": "⏰", + "remove_cronjob": "⏰", + "send_message": "📨", + "todo": "📋", + "memory": "🧠", + "session_search": "🔍", + "clarify": "❓", + "execute_code": "🐍", + "delegate_task": "🔀", } - emoji = tool_emoji_map.get(function_name, '⚡') + emoji = tool_emoji_map.get(function_name, "⚡") preview = _build_tool_preview(function_name, function_args) or function_name if len(preview) > 30: preview = preview[:27] + "..." - spinner = KawaiiSpinner(f"{face} {emoji} {preview}", spinner_type='dots') + spinner = KawaiiSpinner(f"{face} {emoji} {preview}", spinner_type="dots") spinner.start() _spinner_result = None try: @@ -2821,7 +2881,9 @@ class AIAgent: logger.error("handle_function_call raised for %s: %s", function_name, tool_error, exc_info=True) finally: tool_duration = time.time() - tool_start_time - cute_msg = _get_cute_tool_message_impl(function_name, function_args, tool_duration, result=_spinner_result) + cute_msg = _get_cute_tool_message_impl( + function_name, function_args, tool_duration, result=_spinner_result + ) spinner.stop(cute_msg) else: try: @@ -2856,16 +2918,16 @@ class AIAgent: f"exceeding the {MAX_TOOL_RESULT_CHARS:,} char limit]" ) - tool_msg = { - "role": "tool", - "content": function_result, - "tool_call_id": tool_call.id - } + tool_msg = {"role": "tool", "content": function_result, "tool_call_id": tool_call.id} messages.append(tool_msg) self._log_msg_to_db(tool_msg) if not self.quiet_mode: - response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result + response_preview = ( + function_result[: self.log_prefix_chars] + "..." + if len(function_result) > self.log_prefix_chars + else function_result + ) print(f" ✅ Tool {i} completed in {tool_duration:.2f}s - {response_preview}") if self._interrupt_requested and i < len(assistant_message.tool_calls): @@ -2876,7 +2938,7 @@ class AIAgent: skip_msg = { "role": "tool", "content": f"[Tool execution skipped — {skipped_name} was not started. User sent a new message]", - "tool_call_id": skipped_tc.id + "tool_call_id": skipped_tc.id, } messages.append(skip_msg) self._log_msg_to_db(skip_msg) @@ -2923,10 +2985,7 @@ class AIAgent: if self.reasoning_config is not None: summary_extra_body["reasoning"] = self.reasoning_config else: - summary_extra_body["reasoning"] = { - "enabled": True, - "effort": "medium" - } + summary_extra_body["reasoning"] = {"enabled": True, "effort": "medium"} if _is_nous: summary_extra_body["tags"] = ["product=hermes-agent"] @@ -2969,7 +3028,7 @@ class AIAgent: if final_response: if "" in final_response: - final_response = re.sub(r'.*?\s*', '', final_response, flags=re.DOTALL).strip() + final_response = re.sub(r".*?\s*", "", final_response, flags=re.DOTALL).strip() if final_response: messages.append({"role": "assistant", "content": final_response}) else: @@ -3001,7 +3060,7 @@ class AIAgent: if final_response: if "" in final_response: - final_response = re.sub(r'.*?\s*', '', final_response, flags=re.DOTALL).strip() + final_response = re.sub(r".*?\s*", "", final_response, flags=re.DOTALL).strip() if final_response: messages.append({"role": "assistant", "content": final_response}) else: @@ -3011,7 +3070,9 @@ class AIAgent: except Exception as e: logging.warning(f"Failed to get summary response: {e}") - final_response = f"I reached the maximum iterations ({self.max_iterations}) but couldn't summarize. Error: {str(e)}" + final_response = ( + f"I reached the maximum iterations ({self.max_iterations}) but couldn't summarize. Error: {str(e)}" + ) return final_response @@ -3019,9 +3080,9 @@ class AIAgent: self, user_message: str, system_message: str = None, - conversation_history: List[Dict[str, Any]] = None, - task_id: str = None - ) -> Dict[str, Any]: + conversation_history: list[dict[str, Any]] = None, + task_id: str = None, + ) -> dict[str, Any]: """ Run a complete conversation with tool calling until completion. @@ -3036,7 +3097,7 @@ class AIAgent: """ # Generate unique task_id if not provided to isolate VMs between concurrent tasks effective_task_id = task_id or str(uuid.uuid4()) - + # Reset retry counters and iteration budget at the start of each turn # so subagent usage from a previous turn doesn't eat into the next one. self._invalid_tool_retries = 0 @@ -3046,21 +3107,21 @@ class AIAgent: self._turns_since_memory = 0 self._iters_since_skill = 0 self.iteration_budget = IterationBudget(self.max_iterations) - + # Initialize conversation (copy to avoid mutating the caller's list) messages = list(conversation_history) if conversation_history else [] - + # Hydrate todo store from conversation history (gateway creates a fresh # AIAgent per message, so the in-memory store is empty -- we need to # recover the todo state from the most recent todo tool response in history) if conversation_history and not self._todo_store.has_items(): self._hydrate_todo_store(conversation_history) - + # Prefill messages (few-shot priming) are injected at API-call time only, # never stored in the messages list. This keeps them ephemeral: they won't # be saved to session DB, session logs, or batch trajectories, but they're # automatically re-applied on every API call (including session continuations). - + # Track user turns for memory flush and periodic nudge logic self._user_turn_count += 1 @@ -3070,9 +3131,7 @@ class AIAgent: # Periodic memory nudge: remind the model to consider saving memories. # Counter resets whenever the memory tool is actually used. - if (self._memory_nudge_interval > 0 - and "memory" in self.valid_tool_names - and self._memory_store): + if self._memory_nudge_interval > 0 and "memory" in self.valid_tool_names and self._memory_store: self._turns_since_memory += 1 if self._turns_since_memory >= self._memory_nudge_interval: user_message += ( @@ -3083,9 +3142,11 @@ class AIAgent: # Skill creation nudge: fires on the first user message after a long tool loop. # The counter increments per API iteration in the tool loop and is checked here. - if (self._skill_nudge_interval > 0 - and self._iters_since_skill >= self._skill_nudge_interval - and "skill_manage" in self.valid_tool_names): + if ( + self._skill_nudge_interval > 0 + and self._iters_since_skill >= self._skill_nudge_interval + and "skill_manage" in self.valid_tool_names + ): user_message += ( "\n\n[System: The previous task involved many steps. " "If you discovered a reusable workflow, consider saving it as a skill.]" @@ -3109,10 +3170,10 @@ class AIAgent: user_msg = {"role": "user", "content": user_message} messages.append(user_msg) self._log_msg_to_db(user_msg) - + if not self.quiet_mode: print(f"💬 Starting conversation: '{user_message[:60]}{'...' if len(user_message) > 60 else ''}'") - + # ── System prompt (cached per session for prefix caching) ── # Built once on first call, reused for all subsequent calls. # Only rebuilt after context compression events (which invalidate @@ -3144,9 +3205,7 @@ class AIAgent: # Bake Honcho context into the prompt so it's stable for # the entire session (not re-fetched per turn). if self._honcho_context: - self._cached_system_prompt = ( - self._cached_system_prompt + "\n\n" + self._honcho_context - ).strip() + self._cached_system_prompt = (self._cached_system_prompt + "\n\n" + self._honcho_context).strip() # Store the system prompt snapshot in SQLite if self._session_db: try: @@ -3165,8 +3224,7 @@ class AIAgent: # 4xx and abort the request entirely). if ( self.compression_enabled - and len(messages) > self.context_compressor.protect_first_n - + self.context_compressor.protect_last_n + 1 + and len(messages) > self.context_compressor.protect_first_n + self.context_compressor.protect_last_n + 1 ): _sys_tok_est = estimate_tokens_rough(active_system_prompt or "") _msg_tok_est = estimate_messages_tokens_rough(messages) @@ -3206,22 +3264,24 @@ class AIAgent: final_response = None interrupted = False codex_ack_continuations = 0 - + # Clear any stale interrupt state at start self.clear_interrupt() - + while api_call_count < self.max_iterations and self.iteration_budget.remaining > 0: # Check for interrupt request (e.g., user sent new message) if self._interrupt_requested: interrupted = True if not self.quiet_mode: - print(f"\n⚡ Breaking out of tool loop due to interrupt...") + print("\n⚡ Breaking out of tool loop due to interrupt...") break - + api_call_count += 1 if not self.iteration_budget.consume(): if not self.quiet_mode: - print(f"\n⚠️ Session iteration budget exhausted ({self.iteration_budget.max_total} total across agent + subagents)") + print( + f"\n⚠️ Session iteration budget exhausted ({self.iteration_budget.max_total} total across agent + subagents)" + ) break # Fire step_callback for gateway hooks (agent:step event) @@ -3230,11 +3290,7 @@ class AIAgent: prev_tools = [] for _m in reversed(messages): if _m.get("role") == "assistant" and _m.get("tool_calls"): - prev_tools = [ - tc["function"]["name"] - for tc in _m["tool_calls"] - if isinstance(tc, dict) - ] + prev_tools = [tc["function"]["name"] for tc in _m["tool_calls"] if isinstance(tc, dict)] break self.step_callback(api_call_count, prev_tools) except Exception as _step_err: @@ -3242,10 +3298,9 @@ class AIAgent: # Track tool-calling iterations for skill nudge. # Counter resets whenever skill_manage is actually used. - if (self._skill_nudge_interval > 0 - and "skill_manage" in self.valid_tool_names): + if self._skill_nudge_interval > 0 and "skill_manage" in self.valid_tool_names: self._iters_since_skill += 1 - + # Prepare messages for API call # If we have an ephemeral system prompt, prepend it to the messages # Note: Reasoning is embedded in content via tags for trajectory storage. @@ -3254,7 +3309,7 @@ class AIAgent: api_messages = [] for msg in messages: api_msg = msg.copy() - + # For ALL assistant messages, pass reasoning back to the API # This ensures multi-turn reasoning context is preserved if msg.get("role") == "assistant": @@ -3262,7 +3317,7 @@ class AIAgent: if reasoning_text: # Add reasoning_content for API compatibility (Moonshot AI, Novita, OpenRouter) api_msg["reasoning_content"] = reasoning_text - + # Remove 'reasoning' field - it's for trajectory storage only # We've copied it to 'reasoning_content' for the API above if "reasoning" in api_msg: @@ -3273,7 +3328,7 @@ class AIAgent: # Keep 'reasoning_details' - OpenRouter uses this for multi-turn reasoning context # The signature field helps maintain reasoning continuity api_messages.append(api_msg) - + # Build the final system message: cached prompt + ephemeral system prompt. # The ephemeral part is appended here (not baked into the cached prompt) # so it stays out of the session DB and logs. @@ -3286,53 +3341,57 @@ class AIAgent: effective_system = (effective_system + "\n\n" + self.ephemeral_system_prompt).strip() if effective_system: api_messages = [{"role": "system", "content": effective_system}] + api_messages - + # Inject ephemeral prefill messages right after the system prompt # but before conversation history. Same API-call-time-only pattern. if self.prefill_messages: sys_offset = 1 if effective_system else 0 for idx, pfm in enumerate(self.prefill_messages): api_messages.insert(sys_offset + idx, pfm.copy()) - + # Apply Anthropic prompt caching for Claude models via OpenRouter. # Auto-detected: if model name contains "claude" and base_url is OpenRouter, # inject cache_control breakpoints (system + last 3 messages) to reduce # input token costs by ~75% on multi-turn conversations. if self._use_prompt_caching: api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl) - + # Safety net: strip orphaned tool results / add stubs for missing # results before sending to the API. The compressor handles this # during compression, but orphans can also sneak in from session # loading or manual message manipulation. - if hasattr(self, 'context_compressor') and self.context_compressor: + if hasattr(self, "context_compressor") and self.context_compressor: api_messages = self.context_compressor._sanitize_tool_pairs(api_messages) # Calculate approximate request size for logging total_chars = sum(len(str(msg)) for msg in api_messages) approx_tokens = total_chars // 4 # Rough estimate: 4 chars per token - + # Thinking spinner for quiet mode (animated during API call) thinking_spinner = None - + if not self.quiet_mode: print(f"\n{self.log_prefix}🔄 Making API call #{api_call_count}/{self.max_iterations}...") - print(f"{self.log_prefix} 📊 Request size: {len(api_messages)} messages, ~{approx_tokens:,} tokens (~{total_chars:,} chars)") + print( + f"{self.log_prefix} 📊 Request size: {len(api_messages)} messages, ~{approx_tokens:,} tokens (~{total_chars:,} chars)" + ) print(f"{self.log_prefix} 🔧 Available tools: {len(self.tools) if self.tools else 0}") else: # Animated thinking spinner in quiet mode face = random.choice(KawaiiSpinner.KAWAII_THINKING) verb = random.choice(KawaiiSpinner.THINKING_VERBS) - spinner_type = random.choice(['brain', 'sparkle', 'pulse', 'moon', 'star']) + spinner_type = random.choice(["brain", "sparkle", "pulse", "moon", "star"]) thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type) thinking_spinner.start() - + # Log request details if verbose if self.verbose_logging: - logging.debug(f"API Request - Model: {self.model}, Messages: {len(messages)}, Tools: {len(self.tools) if self.tools else 0}") + logging.debug( + f"API Request - Model: {self.model}, Messages: {len(messages)}, Tools: {len(self.tools) if self.tools else 0}" + ) logging.debug(f"Last message role: {messages[-1]['role'] if messages else 'none'}") logging.debug(f"Total message size: ~{approx_tokens:,} tokens") - + api_start_time = time.time() retry_count = 0 max_retries = 6 # Increased to allow longer backoff periods @@ -3354,23 +3413,25 @@ class AIAgent: self._dump_api_request_debug(api_kwargs, reason="preflight") response = self._interruptible_api_call(api_kwargs) - + api_duration = time.time() - api_start_time - + # Stop thinking spinner silently -- the response box or tool # execution messages that follow are more informative. if thinking_spinner: thinking_spinner.stop("") thinking_spinner = None - + if not self.quiet_mode: print(f"{self.log_prefix}⏱️ API call completed in {api_duration:.2f}s") - + if self.verbose_logging: # Log response with provider info if available - resp_model = getattr(response, 'model', 'N/A') if response else 'N/A' - logging.debug(f"API Response received - Model: {resp_model}, Usage: {response.usage if hasattr(response, 'usage') else 'N/A'}") - + resp_model = getattr(response, "model", "N/A") if response else "N/A" + logging.debug( + f"API Response received - Model: {resp_model}, Usage: {response.usage if hasattr(response, 'usage') else 'N/A'}" + ) + # Validate response shape before proceeding response_invalid = False error_details = [] @@ -3386,11 +3447,16 @@ class AIAgent: response_invalid = True error_details.append("response.output is empty") else: - if response is None or not hasattr(response, 'choices') or response.choices is None or len(response.choices) == 0: + if ( + response is None + or not hasattr(response, "choices") + or response.choices is None + or len(response.choices) == 0 + ): response_invalid = True if response is None: error_details.append("response is None") - elif not hasattr(response, 'choices'): + elif not hasattr(response, "choices"): error_details.append("response has no 'choices' attribute") elif response.choices is None: error_details.append("response.choices is None") @@ -3400,45 +3466,51 @@ class AIAgent: if response_invalid: # Stop spinner before printing error messages if thinking_spinner: - thinking_spinner.stop(f"(´;ω;`) oops, retrying...") + thinking_spinner.stop("(´;ω;`) oops, retrying...") thinking_spinner = None - + # This is often rate limiting or provider returning malformed response retry_count += 1 - + # Check for error field in response (some providers include this) error_msg = "Unknown" provider_name = "Unknown" - if response and hasattr(response, 'error') and response.error: + if response and hasattr(response, "error") and response.error: error_msg = str(response.error) # Try to extract provider from error metadata - if hasattr(response.error, 'metadata') and response.error.metadata: - provider_name = response.error.metadata.get('provider_name', 'Unknown') - elif response and hasattr(response, 'message') and response.message: + if hasattr(response.error, "metadata") and response.error.metadata: + provider_name = response.error.metadata.get("provider_name", "Unknown") + elif response and hasattr(response, "message") and response.message: error_msg = str(response.message) - + # Try to get provider from model field (OpenRouter often returns actual model used) - if provider_name == "Unknown" and response and hasattr(response, 'model') and response.model: + if provider_name == "Unknown" and response and hasattr(response, "model") and response.model: provider_name = f"model={response.model}" - + # Check for x-openrouter-provider or similar metadata if provider_name == "Unknown" and response: # Log all response attributes for debugging - resp_attrs = {k: str(v)[:100] for k, v in vars(response).items() if not k.startswith('_')} + resp_attrs = {k: str(v)[:100] for k, v in vars(response).items() if not k.startswith("_")} if self.verbose_logging: logging.debug(f"Response attributes for invalid response: {resp_attrs}") - - print(f"{self.log_prefix}⚠️ Invalid API response (attempt {retry_count}/{max_retries}): {', '.join(error_details)}") + + print( + f"{self.log_prefix}⚠️ Invalid API response (attempt {retry_count}/{max_retries}): {', '.join(error_details)}" + ) print(f"{self.log_prefix} 🏢 Provider: {provider_name}") print(f"{self.log_prefix} 📝 Provider message: {error_msg[:200]}") - print(f"{self.log_prefix} ⏱️ Response time: {api_duration:.2f}s (fast response often indicates rate limiting)") - + print( + f"{self.log_prefix} ⏱️ Response time: {api_duration:.2f}s (fast response often indicates rate limiting)" + ) + if retry_count >= max_retries: # Try fallback before giving up if self._try_activate_fallback(): retry_count = 0 continue - print(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded for invalid responses. Giving up.") + print( + f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded for invalid responses. Giving up." + ) logging.error(f"{self.log_prefix}Invalid API response after {max_retries} retries.") self._persist_session(messages, conversation_history) return { @@ -3446,14 +3518,18 @@ class AIAgent: "completed": False, "api_calls": api_call_count, "error": "Invalid API response shape. Likely rate limited or malformed provider response.", - "failed": True # Mark as failure for filtering + "failed": True, # Mark as failure for filtering } - + # Longer backoff for rate limiting (likely cause of None choices) wait_time = min(5 * (2 ** (retry_count - 1)), 120) # 5s, 10s, 20s, 40s, 80s, 120s - print(f"{self.log_prefix}⏳ Retrying in {wait_time}s (extended backoff for possible rate limit)...") - logging.warning(f"Invalid API response (retry {retry_count}/{max_retries}): {', '.join(error_details)} | Provider: {provider_name}") - + print( + f"{self.log_prefix}⏳ Retrying in {wait_time}s (extended backoff for possible rate limit)..." + ) + logging.warning( + f"Invalid API response (retry {retry_count}/{max_retries}): {', '.join(error_details)} | Provider: {provider_name}" + ) + # Sleep in small increments to stay responsive to interrupts sleep_end = time.time() + wait_time while time.time() < sleep_end: @@ -3486,26 +3562,28 @@ class AIAgent: finish_reason = "stop" else: finish_reason = response.choices[0].finish_reason - + # Handle "length" finish_reason - response was truncated if finish_reason == "length": - print(f"{self.log_prefix}⚠️ Response truncated (finish_reason='length') - model hit max output tokens") - + print( + f"{self.log_prefix}⚠️ Response truncated (finish_reason='length') - model hit max output tokens" + ) + # If we have prior messages, roll back to last complete state if len(messages) > 1: print(f"{self.log_prefix} ⏪ Rolling back to last complete assistant turn") rolled_back_messages = self._get_messages_up_to_last_assistant(messages) - + self._cleanup_task_resources(effective_task_id) self._persist_session(messages, conversation_history) - + return { "final_response": None, "messages": rolled_back_messages, "api_calls": api_call_count, "completed": False, "partial": True, - "error": "Response truncated due to output length limit" + "error": "Response truncated due to output length limit", } else: # First message was truncated - mark as failed @@ -3517,22 +3595,21 @@ class AIAgent: "api_calls": api_call_count, "completed": False, "failed": True, - "error": "First response truncated due to output length limit" + "error": "First response truncated due to output length limit", } - + # Track actual token usage from response for context management - if hasattr(response, 'usage') and response.usage: + if hasattr(response, "usage") and response.usage: if self.api_mode == "codex_responses": - prompt_tokens = getattr(response.usage, 'input_tokens', 0) or 0 - completion_tokens = getattr(response.usage, 'output_tokens', 0) or 0 - total_tokens = ( - getattr(response.usage, 'total_tokens', None) - or (prompt_tokens + completion_tokens) + prompt_tokens = getattr(response.usage, "input_tokens", 0) or 0 + completion_tokens = getattr(response.usage, "output_tokens", 0) or 0 + total_tokens = getattr(response.usage, "total_tokens", None) or ( + prompt_tokens + completion_tokens ) else: - prompt_tokens = getattr(response.usage, 'prompt_tokens', 0) or 0 - completion_tokens = getattr(response.usage, 'completion_tokens', 0) or 0 - total_tokens = getattr(response.usage, 'total_tokens', 0) or 0 + prompt_tokens = getattr(response.usage, "prompt_tokens", 0) or 0 + completion_tokens = getattr(response.usage, "completion_tokens", 0) or 0 + total_tokens = getattr(response.usage, "total_tokens", 0) or 0 usage_dict = { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, @@ -3551,20 +3628,24 @@ class AIAgent: self.session_completion_tokens += completion_tokens self.session_total_tokens += total_tokens self.session_api_calls += 1 - + if self.verbose_logging: - logging.debug(f"Token usage: prompt={usage_dict['prompt_tokens']:,}, completion={usage_dict['completion_tokens']:,}, total={usage_dict['total_tokens']:,}") - + logging.debug( + f"Token usage: prompt={usage_dict['prompt_tokens']:,}, completion={usage_dict['completion_tokens']:,}, total={usage_dict['total_tokens']:,}" + ) + # Log cache hit stats when prompt caching is active if self._use_prompt_caching: - details = getattr(response.usage, 'prompt_tokens_details', None) - cached = getattr(details, 'cached_tokens', 0) or 0 if details else 0 - written = getattr(details, 'cache_write_tokens', 0) or 0 if details else 0 + details = getattr(response.usage, "prompt_tokens_details", None) + cached = getattr(details, "cached_tokens", 0) or 0 if details else 0 + written = getattr(details, "cache_write_tokens", 0) or 0 if details else 0 prompt = usage_dict["prompt_tokens"] hit_pct = (cached / prompt * 100) if prompt > 0 else 0 if not self.quiet_mode: - print(f"{self.log_prefix} 💾 Cache: {cached:,}/{prompt:,} tokens ({hit_pct:.0f}% hit, {written:,} written)") - + print( + f"{self.log_prefix} 💾 Cache: {cached:,}/{prompt:,} tokens ({hit_pct:.0f}% hit, {written:,} written)" + ) + break # Success, exit retry loop except InterruptedError: @@ -3581,7 +3662,7 @@ class AIAgent: except Exception as api_error: # Stop spinner before printing error messages if thinking_spinner: - thinking_spinner.stop(f"(╥_╥) error, retrying...") + thinking_spinner.stop("(╥_╥) error, retrying...") thinking_spinner = None status_code = getattr(api_error, "status_code", None) @@ -3608,16 +3689,18 @@ class AIAgent: retry_count += 1 elapsed_time = time.time() - api_start_time - + # Enhanced error logging error_type = type(api_error).__name__ error_msg = str(api_error).lower() - + print(f"{self.log_prefix}⚠️ API call failed (attempt {retry_count}/{max_retries}): {error_type}") print(f"{self.log_prefix} ⏱️ Time elapsed before failure: {elapsed_time:.2f}s") print(f"{self.log_prefix} 📝 Error: {str(api_error)[:200]}") - print(f"{self.log_prefix} 📊 Request context: {len(api_messages)} messages, ~{approx_tokens:,} tokens, {len(self.tools) if self.tools else 0} tools") - + print( + f"{self.log_prefix} 📊 Request context: {len(api_messages)} messages, ~{approx_tokens:,} tokens, {len(self.tools) if self.tools else 0} tools" + ) + # Check for interrupt before deciding to retry if self._interrupt_requested: print(f"{self.log_prefix}⚡ Interrupt detected during error handling, aborting retries.") @@ -3630,32 +3713,38 @@ class AIAgent: "completed": False, "interrupted": True, } - + # Check for 413 payload-too-large BEFORE generic 4xx handler. # A 413 is a payload-size error — the correct response is to # compress history and retry, not abort immediately. status_code = getattr(api_error, "status_code", None) is_payload_too_large = ( status_code == 413 - or 'request entity too large' in error_msg - or 'payload too large' in error_msg - or 'error code: 413' in error_msg + or "request entity too large" in error_msg + or "payload too large" in error_msg + or "error code: 413" in error_msg ) if is_payload_too_large: compression_attempts += 1 if compression_attempts > max_compression_attempts: - print(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached for payload-too-large error.") - logging.error(f"{self.log_prefix}413 compression failed after {max_compression_attempts} attempts.") + print( + f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached for payload-too-large error." + ) + logging.error( + f"{self.log_prefix}413 compression failed after {max_compression_attempts} attempts." + ) self._persist_session(messages, conversation_history) return { "messages": messages, "completed": False, "api_calls": api_call_count, "error": f"Request payload too large: max compression attempts ({max_compression_attempts}) reached.", - "partial": True + "partial": True, } - print(f"{self.log_prefix}⚠️ Request payload too large (413) — compression attempt {compression_attempts}/{max_compression_attempts}...") + print( + f"{self.log_prefix}⚠️ Request payload too large (413) — compression attempt {compression_attempts}/{max_compression_attempts}..." + ) original_len = len(messages) messages, active_system_prompt = self._compress_context( @@ -3663,7 +3752,9 @@ class AIAgent: ) if len(messages) < original_len: - print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...") + print( + f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying..." + ) time.sleep(2) # Brief pause between compression retries continue # Retry with compressed messages else: @@ -3675,20 +3766,28 @@ class AIAgent: "completed": False, "api_calls": api_call_count, "error": "Request payload too large (413). Cannot compress further.", - "partial": True + "partial": True, } # Check for context-length errors BEFORE generic 4xx handler. # Local backends (LM Studio, Ollama, llama.cpp) often return # HTTP 400 with messages like "Context size has been exceeded" # which must trigger compression, not an immediate abort. - is_context_length_error = any(phrase in error_msg for phrase in [ - 'context length', 'context size', 'maximum context', - 'token limit', 'too many tokens', 'reduce the length', - 'exceeds the limit', 'context window', - 'request entity too large', # OpenRouter/Nous 413 safety net - ]) - + is_context_length_error = any( + phrase in error_msg + for phrase in [ + "context length", + "context size", + "maximum context", + "token limit", + "too many tokens", + "reduce the length", + "exceeds the limit", + "context window", + "request entity too large", # OpenRouter/Nous 413 safety net + ] + ) + if is_context_length_error: compressor = self.context_compressor old_ctx = compressor.context_length @@ -3697,7 +3796,9 @@ class AIAgent: parsed_limit = parse_context_limit_from_error(error_msg) if parsed_limit and parsed_limit < old_ctx: new_ctx = parsed_limit - print(f"{self.log_prefix}⚠️ Context limit detected from API: {new_ctx:,} tokens (was {old_ctx:,})") + print( + f"{self.log_prefix}⚠️ Context limit detected from API: {new_ctx:,} tokens (was {old_ctx:,})" + ) else: # Step down to the next probe tier new_ctx = get_next_probe_tier(old_ctx) @@ -3706,23 +3807,31 @@ class AIAgent: compressor.context_length = new_ctx compressor.threshold_tokens = int(new_ctx * compressor.threshold_percent) compressor._context_probed = True - print(f"{self.log_prefix}⚠️ Context length exceeded — stepping down: {old_ctx:,} → {new_ctx:,} tokens") + print( + f"{self.log_prefix}⚠️ Context length exceeded — stepping down: {old_ctx:,} → {new_ctx:,} tokens" + ) else: - print(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...") + print( + f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression..." + ) compression_attempts += 1 if compression_attempts > max_compression_attempts: print(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached.") - logging.error(f"{self.log_prefix}Context compression failed after {max_compression_attempts} attempts.") + logging.error( + f"{self.log_prefix}Context compression failed after {max_compression_attempts} attempts." + ) self._persist_session(messages, conversation_history) return { "messages": messages, "completed": False, "api_calls": api_call_count, "error": f"Context length exceeded: max compression attempts ({max_compression_attempts}) reached.", - "partial": True + "partial": True, } - print(f"{self.log_prefix} 🗜️ Context compression attempt {compression_attempts}/{max_compression_attempts}...") + print( + f"{self.log_prefix} 🗜️ Context compression attempt {compression_attempts}/{max_compression_attempts}..." + ) original_len = len(messages) messages, active_system_prompt = self._compress_context( @@ -3731,35 +3840,55 @@ class AIAgent: if len(messages) < original_len or new_ctx and new_ctx < old_ctx: if len(messages) < original_len: - print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...") + print( + f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying..." + ) time.sleep(2) # Brief pause between compression retries continue # Retry with compressed messages or new tier else: # Can't compress further and already at minimum tier print(f"{self.log_prefix}❌ Context length exceeded and cannot compress further.") print(f"{self.log_prefix} 💡 The conversation has accumulated too much content.") - logging.error(f"{self.log_prefix}Context length exceeded: {approx_tokens:,} tokens. Cannot compress further.") + logging.error( + f"{self.log_prefix}Context length exceeded: {approx_tokens:,} tokens. Cannot compress further." + ) self._persist_session(messages, conversation_history) return { "messages": messages, "completed": False, "api_calls": api_call_count, "error": f"Context length exceeded ({approx_tokens:,} tokens). Cannot compress further.", - "partial": True + "partial": True, } # Check for non-retryable client errors (4xx HTTP status codes). # These indicate a problem with the request itself (bad model ID, # invalid API key, forbidden, etc.) and will never succeed on retry. # Note: 413 and context-length errors are excluded — handled above. - is_client_status_error = isinstance(status_code, int) and 400 <= status_code < 500 and status_code != 413 - is_client_error = (is_client_status_error or any(phrase in error_msg for phrase in [ - 'error code: 401', 'error code: 403', - 'error code: 404', 'error code: 422', - 'is not a valid model', 'invalid model', 'model not found', - 'invalid api key', 'invalid_api_key', 'authentication', - 'unauthorized', 'forbidden', 'not found', - ])) and not is_context_length_error + is_client_status_error = ( + isinstance(status_code, int) and 400 <= status_code < 500 and status_code != 413 + ) + is_client_error = ( + is_client_status_error + or any( + phrase in error_msg + for phrase in [ + "error code: 401", + "error code: 403", + "error code: 404", + "error code: 422", + "is not a valid model", + "invalid model", + "model not found", + "invalid api key", + "invalid_api_key", + "authentication", + "unauthorized", + "forbidden", + "not found", + ] + ) + ) and not is_context_length_error if is_client_error: # Try fallback before aborting — a different provider @@ -3768,7 +3897,9 @@ class AIAgent: retry_count = 0 continue self._dump_api_request_debug( - api_kwargs, reason="non_retryable_client_error", error=api_error, + api_kwargs, + reason="non_retryable_client_error", + error=api_error, ) print(f"{self.log_prefix}❌ Non-retryable client error detected. Aborting immediately.") print(f"{self.log_prefix} 💡 This type of error won't be fixed by retrying.") @@ -3789,16 +3920,22 @@ class AIAgent: retry_count = 0 continue print(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded. Giving up.") - logging.error(f"{self.log_prefix}API call failed after {max_retries} retries. Last error: {api_error}") - logging.error(f"{self.log_prefix}Request details - Messages: {len(api_messages)}, Approx tokens: {approx_tokens:,}") + logging.error( + f"{self.log_prefix}API call failed after {max_retries} retries. Last error: {api_error}" + ) + logging.error( + f"{self.log_prefix}Request details - Messages: {len(api_messages)}, Approx tokens: {approx_tokens:,}" + ) raise api_error - wait_time = min(2 ** retry_count, 60) # Exponential backoff: 2s, 4s, 8s, 16s, 32s, 60s, 60s + wait_time = min(2**retry_count, 60) # Exponential backoff: 2s, 4s, 8s, 16s, 32s, 60s, 60s logging.warning(f"API retry {retry_count}/{max_retries} after error: {api_error}") if retry_count >= max_retries: - print(f"{self.log_prefix}⚠️ API call failed after {retry_count} attempts: {str(api_error)[:100]}") + print( + f"{self.log_prefix}⚠️ API call failed after {retry_count} attempts: {str(api_error)[:100]}" + ) print(f"{self.log_prefix}⏳ Final retry in {wait_time}s...") - + # Sleep in small increments so we can respond to interrupts quickly # instead of blocking the entire wait_time in one sleep() call sleep_end = time.time() + wait_time @@ -3815,7 +3952,7 @@ class AIAgent: "interrupted": True, } time.sleep(0.2) # Check interrupt every 200ms - + # If the API call was interrupted, skip response processing if interrupted: break @@ -3833,7 +3970,7 @@ class AIAgent: assistant_message, finish_reason = self._normalize_codex_response(response) else: assistant_message = response.choices[0].message - + # Normalize content to string — some OpenAI-compatible servers # (llama-server, etc.) return content as a dict or list instead # of a plain string, which crashes downstream .strip() calls. @@ -3857,35 +3994,38 @@ class AIAgent: # Handle assistant response if assistant_message.content and not self.quiet_mode: - print(f"{self.log_prefix}🤖 Assistant: {assistant_message.content[:100]}{'...' if len(assistant_message.content) > 100 else ''}") + print( + f"{self.log_prefix}🤖 Assistant: {assistant_message.content[:100]}{'...' if len(assistant_message.content) > 100 else ''}" + ) # Notify progress callback of model's thinking (used by subagent # delegation to relay the child's reasoning to the parent display). # Guard: only fire for subagents (_delegate_depth >= 1) to avoid # spamming gateway platforms with the main agent's every thought. - if (assistant_message.content and self.tool_progress_callback - and getattr(self, '_delegate_depth', 0) > 0): + if ( + assistant_message.content + and self.tool_progress_callback + and getattr(self, "_delegate_depth", 0) > 0 + ): _think_text = assistant_message.content.strip() # Strip reasoning XML tags that shouldn't leak to parent display - _think_text = re.sub( - r'', '', _think_text - ).strip() - first_line = _think_text.split('\n')[0][:80] if _think_text else "" + _think_text = re.sub(r"", "", _think_text).strip() + first_line = _think_text.split("\n")[0][:80] if _think_text else "" if first_line: try: self.tool_progress_callback("_thinking", first_line) except Exception: pass - + # Check for incomplete (opened but never closed) # This means the model ran out of output tokens mid-reasoning — retry up to 2 times if has_incomplete_scratchpad(assistant_message.content or ""): - if not hasattr(self, '_incomplete_scratchpad_retries'): + if not hasattr(self, "_incomplete_scratchpad_retries"): self._incomplete_scratchpad_retries = 0 self._incomplete_scratchpad_retries += 1 - + print(f"{self.log_prefix}⚠️ Incomplete detected (opened but never closed)") - + if self._incomplete_scratchpad_retries <= 2: print(f"{self.log_prefix}🔄 Retrying API call ({self._incomplete_scratchpad_retries}/2)...") # Don't add the broken message, just retry @@ -3894,22 +4034,22 @@ class AIAgent: # Max retries - discard this turn and save as partial print(f"{self.log_prefix}❌ Max retries (2) for incomplete scratchpad. Saving as partial.") self._incomplete_scratchpad_retries = 0 - + rolled_back_messages = self._get_messages_up_to_last_assistant(messages) self._cleanup_task_resources(effective_task_id) self._persist_session(messages, conversation_history) - + return { "final_response": None, "messages": rolled_back_messages, "api_calls": api_call_count, "completed": False, "partial": True, - "error": "Incomplete REASONING_SCRATCHPAD after 2 retries" + "error": "Incomplete REASONING_SCRATCHPAD after 2 retries", } - + # Reset incomplete scratchpad counter on clean response - if hasattr(self, '_incomplete_scratchpad_retries'): + if hasattr(self, "_incomplete_scratchpad_retries"): self._incomplete_scratchpad_retries = 0 if self.api_mode == "codex_responses" and finish_reason == "incomplete": @@ -3919,7 +4059,11 @@ class AIAgent: interim_msg = self._build_assistant_message(assistant_message, finish_reason) interim_has_content = bool((interim_msg.get("content") or "").strip()) - interim_has_reasoning = bool(interim_msg.get("reasoning", "").strip()) if isinstance(interim_msg.get("reasoning"), str) else False + interim_has_reasoning = ( + bool(interim_msg.get("reasoning", "").strip()) + if isinstance(interim_msg.get("reasoning"), str) + else False + ) if interim_has_content or interim_has_reasoning: last_msg = messages[-1] if messages else None @@ -3936,7 +4080,9 @@ class AIAgent: if self._codex_incomplete_retries < 3: if not self.quiet_mode: - print(f"{self.log_prefix}↻ Codex response incomplete; continuing turn ({self._codex_incomplete_retries}/3)") + print( + f"{self.log_prefix}↻ Codex response incomplete; continuing turn ({self._codex_incomplete_retries}/3)" + ) self._session_messages = messages self._save_session_log(messages) continue @@ -3953,38 +4099,45 @@ class AIAgent: } elif hasattr(self, "_codex_incomplete_retries"): self._codex_incomplete_retries = 0 - + # Check for tool calls if assistant_message.tool_calls: if not self.quiet_mode: print(f"{self.log_prefix}🔧 Processing {len(assistant_message.tool_calls)} tool call(s)...") - + if self.verbose_logging: for tc in assistant_message.tool_calls: logging.debug(f"Tool call: {tc.function.name} with args: {tc.function.arguments[:200]}...") - + # Validate tool call names - detect model hallucinations invalid_tool_calls = [ - tc.function.name for tc in assistant_message.tool_calls + tc.function.name + for tc in assistant_message.tool_calls if tc.function.name not in self.valid_tool_names ] - + if invalid_tool_calls: # Track retries for invalid tool calls - if not hasattr(self, '_invalid_tool_retries'): + if not hasattr(self, "_invalid_tool_retries"): self._invalid_tool_retries = 0 self._invalid_tool_retries += 1 - - invalid_preview = invalid_tool_calls[0][:80] + "..." if len(invalid_tool_calls[0]) > 80 else invalid_tool_calls[0] + + invalid_preview = ( + invalid_tool_calls[0][:80] + "..." + if len(invalid_tool_calls[0]) > 80 + else invalid_tool_calls[0] + ) print(f"{self.log_prefix}⚠️ Invalid tool call detected: '{invalid_preview}'") print(f"{self.log_prefix} Valid tools: {sorted(self.valid_tool_names)}") - + if self._invalid_tool_retries < 3: print(f"{self.log_prefix}🔄 Retrying API call ({self._invalid_tool_retries}/3)...") # Don't add anything to messages, just retry the API call continue else: - print(f"{self.log_prefix}❌ Max retries (3) for invalid tool calls exceeded. Stopping as partial.") + print( + f"{self.log_prefix}❌ Max retries (3) for invalid tool calls exceeded. Stopping as partial." + ) # Return partial result - don't include the bad tool call in messages self._invalid_tool_retries = 0 self._persist_session(messages, conversation_history) @@ -3994,13 +4147,13 @@ class AIAgent: "api_calls": api_call_count, "completed": False, "partial": True, - "error": f"Model generated invalid tool call: {invalid_preview}" + "error": f"Model generated invalid tool call: {invalid_preview}", } - + # Reset retry counter on successful tool call validation - if hasattr(self, '_invalid_tool_retries'): + if hasattr(self, "_invalid_tool_retries"): self._invalid_tool_retries = 0 - + # Validate tool call arguments are valid JSON # Handle empty strings as empty objects (common model quirk) invalid_json_args = [] @@ -4014,14 +4167,14 @@ class AIAgent: json.loads(args) except json.JSONDecodeError as e: invalid_json_args.append((tc.function.name, str(e))) - + if invalid_json_args: # Track retries for invalid JSON arguments self._invalid_json_retries += 1 - + tool_name, error_msg = invalid_json_args[0] print(f"{self.log_prefix}⚠️ Invalid JSON in tool call arguments for '{tool_name}': {error_msg}") - + if self._invalid_json_retries < 3: print(f"{self.log_prefix}🔄 Retrying API call ({self._invalid_json_retries}/3)...") # Don't add anything to messages, just retry the API call @@ -4030,7 +4183,7 @@ class AIAgent: # Instead of returning partial, inject a helpful message and let model recover print(f"{self.log_prefix}⚠️ Injecting recovery message for invalid JSON...") self._invalid_json_retries = 0 # Reset for next attempt - + # Add a user message explaining the issue recovery_msg = ( f"Your tool call to '{tool_name}' had invalid JSON arguments. " @@ -4042,12 +4195,12 @@ class AIAgent: messages.append(recovery_dict) self._log_msg_to_db(recovery_dict) continue - + # Reset retry counter on successful JSON validation self._invalid_json_retries = 0 - + assistant_msg = self._build_assistant_message(assistant_message, finish_reason) - + # If this turn has both content AND tool_calls, capture the content # as a fallback final response. Common pattern: model delivers its # answer and calls memory/skill tools as a side-effect in the same @@ -4060,10 +4213,10 @@ class AIAgent: clean = self._strip_think_blocks(turn_content).strip() if clean: print(f" ┊ 💬 {clean}") - + messages.append(assistant_msg) self._log_msg_to_db(assistant_msg) - + self._execute_tool_calls(assistant_message, messages, effective_task_id) # Refund the iteration if the ONLY tool(s) called were @@ -4072,33 +4225,34 @@ class AIAgent: _tc_names = {tc.function.name for tc in assistant_message.tool_calls} if _tc_names == {"execute_code"}: self.iteration_budget.refund() - + if self.compression_enabled and self.context_compressor.should_compress(): messages, active_system_prompt = self._compress_context( - messages, system_message, - approx_tokens=self.context_compressor.last_prompt_tokens + messages, system_message, approx_tokens=self.context_compressor.last_prompt_tokens ) - + # Save session log incrementally (so progress is visible even if interrupted) self._session_messages = messages self._save_session_log(messages) - + # Continue loop for next response continue - + else: # No tool calls - this is the final response final_response = assistant_message.content or "" - + # Check if response only has think block with no actual content after it if not self._has_content_after_think_block(final_response): # If the previous turn already delivered real content alongside # tool calls (e.g. "You're welcome!" + memory save), the model # has nothing more to say. Use the earlier content immediately # instead of wasting API calls on retries that won't help. - fallback = getattr(self, '_last_content_with_tools', None) + fallback = getattr(self, "_last_content_with_tools", None) if fallback: - logger.debug("Empty follow-up after tool calls — using prior turn content as final response") + logger.debug( + "Empty follow-up after tool calls — using prior turn content as final response" + ) self._last_content_with_tools = None self._empty_content_retries = 0 for i in range(len(messages) - 1, -1, -1): @@ -4108,37 +4262,43 @@ class AIAgent: for tc in msg["tool_calls"]: fn = tc.get("function", {}) tool_names.append(fn.get("name", "unknown")) - msg["content"] = f"Calling the {', '.join(tool_names)} tool{'s' if len(tool_names) > 1 else ''}..." + msg["content"] = ( + f"Calling the {', '.join(tool_names)} tool{'s' if len(tool_names) > 1 else ''}..." + ) break final_response = self._strip_think_blocks(fallback).strip() break # No fallback available — this is a genuine empty response. # Retry in case the model just had a bad generation. - if not hasattr(self, '_empty_content_retries'): + if not hasattr(self, "_empty_content_retries"): self._empty_content_retries = 0 self._empty_content_retries += 1 - + reasoning_text = self._extract_reasoning(assistant_message) print(f"{self.log_prefix}⚠️ Response only contains think block with no content after it") if reasoning_text: - reasoning_preview = reasoning_text[:500] + "..." if len(reasoning_text) > 500 else reasoning_text + reasoning_preview = ( + reasoning_text[:500] + "..." if len(reasoning_text) > 500 else reasoning_text + ) print(f"{self.log_prefix} Reasoning: {reasoning_preview}") else: - content_preview = final_response[:80] + "..." if len(final_response) > 80 else final_response + content_preview = ( + final_response[:80] + "..." if len(final_response) > 80 else final_response + ) print(f"{self.log_prefix} Content: '{content_preview}'") - + if self._empty_content_retries < 3: print(f"{self.log_prefix}🔄 Retrying API call ({self._empty_content_retries}/3)...") continue else: print(f"{self.log_prefix}❌ Max retries (3) for empty content exceeded.") self._empty_content_retries = 0 - + # If a prior tool_calls turn had real content, salvage it: # rewrite that turn's content to a brief tool description, # and use the original content as the final response here. - fallback = getattr(self, '_last_content_with_tools', None) + fallback = getattr(self, "_last_content_with_tools", None) if fallback: self._last_content_with_tools = None # Find the last assistant message with tool_calls and rewrite it @@ -4149,12 +4309,14 @@ class AIAgent: for tc in msg["tool_calls"]: fn = tc.get("function", {}) tool_names.append(fn.get("name", "unknown")) - msg["content"] = f"Calling the {', '.join(tool_names)} tool{'s' if len(tool_names) > 1 else ''}..." + msg["content"] = ( + f"Calling the {', '.join(tool_names)} tool{'s' if len(tool_names) > 1 else ''}..." + ) break # Strip blocks from fallback content for user display final_response = self._strip_think_blocks(fallback).strip() break - + # No fallback -- append the empty message as-is empty_msg = { "role": "assistant", @@ -4164,21 +4326,21 @@ class AIAgent: } messages.append(empty_msg) self._log_msg_to_db(empty_msg) - + self._cleanup_task_resources(effective_task_id) self._persist_session(messages, conversation_history) - + return { "final_response": final_response or None, "messages": messages, "api_calls": api_call_count, "completed": False, "partial": True, - "error": "Model generated only think blocks with no actual response after 3 retries" + "error": "Model generated only think blocks with no actual response after 3 retries", } - + # Reset retry counter on successful content - if hasattr(self, '_empty_content_retries'): + if hasattr(self, "_empty_content_retries"): self._empty_content_retries = 0 if ( @@ -4210,26 +4372,26 @@ class AIAgent: continue codex_ack_continuations = 0 - + # Strip blocks from user-facing response (keep raw in messages for trajectory) final_response = self._strip_think_blocks(final_response).strip() - + final_msg = self._build_assistant_message(assistant_message, finish_reason) - + messages.append(final_msg) self._log_msg_to_db(final_msg) - + if not self.quiet_mode: print(f"🎉 Conversation completed after {api_call_count} OpenAI-compatible API call(s)") break - + except Exception as e: error_msg = f"Error during OpenAI-compatible API call #{api_call_count}: {str(e)}" print(f"❌ {error_msg}") - + if self.verbose_logging: logging.exception("Detailed error information:") - + # If an assistant message with tool_calls was already appended, # the API expects a role="tool" result for every tool_call_id. # Fill in error results for any that weren't answered yet. @@ -4243,7 +4405,7 @@ class AIAgent: if msg.get("role") == "assistant" and msg.get("tool_calls"): answered_ids = { m["tool_call_id"] - for m in messages[idx + 1:] + for m in messages[idx + 1 :] if isinstance(m, dict) and m.get("role") == "tool" } for tc in msg["tool_calls"]: @@ -4257,7 +4419,7 @@ class AIAgent: self._log_msg_to_db(err_msg) pending_handled = True break - + if not pending_handled: # Error happened before tool processing (e.g. response parsing). # Use a user-role message so the model can see what went wrong @@ -4268,20 +4430,19 @@ class AIAgent: } messages.append(sys_err_msg) self._log_msg_to_db(sys_err_msg) - + # If we're near the limit, break to avoid infinite loops if api_call_count >= self.max_iterations - 1: final_response = f"I apologize, but I encountered repeated errors: {error_msg}" break - - if final_response is None and ( - api_call_count >= self.max_iterations - or self.iteration_budget.remaining <= 0 - ): + + if final_response is None and (api_call_count >= self.max_iterations or self.iteration_budget.remaining <= 0): if self.iteration_budget.remaining <= 0 and not self.quiet_mode: - print(f"\n⚠️ Session iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} used, including subagents)") + print( + f"\n⚠️ Session iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} used, including subagents)" + ) final_response = self._handle_max_iterations(messages, api_call_count) - + # Determine if conversation completed successfully completed = final_response is not None and api_call_count < self.max_iterations @@ -4307,23 +4468,23 @@ class AIAgent: "partial": False, # True only when stopped due to invalid tool calls "interrupted": interrupted, } - + # Include interrupt message if one triggered the interrupt if interrupted and self._interrupt_message: result["interrupt_message"] = self._interrupt_message - + # Clear interrupt state after handling self.clear_interrupt() - + return result - + def chat(self, message: str) -> str: """ Simple chat interface that returns just the final response. - + Args: message (str): User message - + Returns: str: Final assistant response """ @@ -4343,7 +4504,7 @@ def main( save_trajectories: bool = False, save_sample: bool = False, verbose: bool = False, - log_prefix_chars: int = 20 + log_prefix_chars: int = 20, ): """ Main function for running the agent directly. @@ -4369,25 +4530,25 @@ def main( """ print("🤖 AI Agent with Tool Calling") print("=" * 50) - + # Handle tool listing if list_tools: - from model_tools import get_all_tool_names, get_toolset_for_tool, get_available_toolsets + from model_tools import get_all_tool_names, get_available_toolsets, get_toolset_for_tool from toolsets import get_all_toolsets, get_toolset_info - + print("📋 Available Tools & Toolsets:") print("-" * 50) - + # Show new toolsets system print("\n🎯 Predefined Toolsets (New System):") print("-" * 40) all_toolsets = get_all_toolsets() - + # Group by category basic_toolsets = [] composite_toolsets = [] scenario_toolsets = [] - + for name, toolset in all_toolsets.items(): info = get_toolset_info(name) if info: @@ -4398,29 +4559,28 @@ def main( composite_toolsets.append(entry) else: scenario_toolsets.append(entry) - + # Print basic toolsets print("\n📌 Basic Toolsets:") for name, info in basic_toolsets: - tools_str = ', '.join(info['resolved_tools']) if info['resolved_tools'] else 'none' + tools_str = ", ".join(info["resolved_tools"]) if info["resolved_tools"] else "none" print(f" • {name:15} - {info['description']}") print(f" Tools: {tools_str}") - + # Print composite toolsets print("\n📂 Composite Toolsets (built from other toolsets):") for name, info in composite_toolsets: - includes_str = ', '.join(info['includes']) if info['includes'] else 'none' + includes_str = ", ".join(info["includes"]) if info["includes"] else "none" print(f" • {name:15} - {info['description']}") print(f" Includes: {includes_str}") print(f" Total tools: {info['tool_count']}") - + # Print scenario-specific toolsets print("\n🎭 Scenario-Specific Toolsets:") for name, info in scenario_toolsets: print(f" • {name:20} - {info['description']}") print(f" Total tools: {info['tool_count']}") - - + # Show legacy toolset compatibility print("\n📦 Legacy Toolsets (for backward compatibility):") legacy_toolsets = get_available_toolsets() @@ -4429,47 +4589,47 @@ def main( print(f" {status} {name}: {info['description']}") if not info["available"]: print(f" Requirements: {', '.join(info['requirements'])}") - + # Show individual tools all_tools = get_all_tool_names() print(f"\n🔧 Individual Tools ({len(all_tools)} available):") for tool_name in sorted(all_tools): toolset = get_toolset_for_tool(tool_name) print(f" 📌 {tool_name} (from {toolset})") - - print(f"\n💡 Usage Examples:") - print(f" # Use predefined toolsets") - print(f" python run_agent.py --enabled_toolsets=research --query='search for Python news'") - print(f" python run_agent.py --enabled_toolsets=development --query='debug this code'") - print(f" python run_agent.py --enabled_toolsets=safe --query='analyze without terminal'") - print(f" ") - print(f" # Combine multiple toolsets") - print(f" python run_agent.py --enabled_toolsets=web,vision --query='analyze website'") - print(f" ") - print(f" # Disable toolsets") - print(f" python run_agent.py --disabled_toolsets=terminal --query='no command execution'") - print(f" ") - print(f" # Run with trajectory saving enabled") - print(f" python run_agent.py --save_trajectories --query='your question here'") + + print("\n💡 Usage Examples:") + print(" # Use predefined toolsets") + print(" python run_agent.py --enabled_toolsets=research --query='search for Python news'") + print(" python run_agent.py --enabled_toolsets=development --query='debug this code'") + print(" python run_agent.py --enabled_toolsets=safe --query='analyze without terminal'") + print(" ") + print(" # Combine multiple toolsets") + print(" python run_agent.py --enabled_toolsets=web,vision --query='analyze website'") + print(" ") + print(" # Disable toolsets") + print(" python run_agent.py --disabled_toolsets=terminal --query='no command execution'") + print(" ") + print(" # Run with trajectory saving enabled") + print(" python run_agent.py --save_trajectories --query='your question here'") return - + # Parse toolset selection arguments enabled_toolsets_list = None disabled_toolsets_list = None - + if enabled_toolsets: enabled_toolsets_list = [t.strip() for t in enabled_toolsets.split(",")] print(f"🎯 Enabled toolsets: {enabled_toolsets_list}") - + if disabled_toolsets: disabled_toolsets_list = [t.strip() for t in disabled_toolsets.split(",")] print(f"🚫 Disabled toolsets: {disabled_toolsets_list}") - + if save_trajectories: - print(f"💾 Trajectory saving: ENABLED") - print(f" - Successful conversations → trajectory_samples.jsonl") - print(f" - Failed conversations → failed_trajectories.jsonl") - + print("💾 Trajectory saving: ENABLED") + print(" - Successful conversations → trajectory_samples.jsonl") + print(" - Failed conversations → failed_trajectories.jsonl") + # Initialize agent with provided parameters try: agent = AIAgent( @@ -4481,12 +4641,12 @@ def main( disabled_toolsets=disabled_toolsets_list, save_trajectories=save_trajectories, verbose_logging=verbose, - log_prefix_chars=log_prefix_chars + log_prefix_chars=log_prefix_chars, ) except RuntimeError as e: print(f"❌ Failed to initialize agent: {e}") return - + # Use provided query or default to Python 3.13 example if query is None: user_query = ( @@ -4495,45 +4655,41 @@ def main( ) else: user_query = query - + print(f"\n📝 User Query: {user_query}") print("\n" + "=" * 50) - + # Run conversation result = agent.run_conversation(user_query) - + print("\n" + "=" * 50) print("📋 CONVERSATION SUMMARY") print("=" * 50) print(f"✅ Completed: {result['completed']}") print(f"📞 API Calls: {result['api_calls']}") print(f"💬 Messages: {len(result['messages'])}") - - if result['final_response']: - print(f"\n🎯 FINAL RESPONSE:") + + if result["final_response"]: + print("\n🎯 FINAL RESPONSE:") print("-" * 30) - print(result['final_response']) - + print(result["final_response"]) + # Save sample trajectory to UUID-named file if requested if save_sample: sample_id = str(uuid.uuid4())[:8] sample_filename = f"sample_{sample_id}.json" - + # Convert messages to trajectory format (same as batch_runner) - trajectory = agent._convert_to_trajectory_format( - result['messages'], - user_query, - result['completed'] - ) - + trajectory = agent._convert_to_trajectory_format(result["messages"], user_query, result["completed"]) + entry = { "conversations": trajectory, "timestamp": datetime.now().isoformat(), "model": model, - "completed": result['completed'], - "query": user_query + "completed": result["completed"], + "query": user_query, } - + try: with open(sample_filename, "w", encoding="utf-8") as f: # Pretty-print JSON with indent for readability @@ -4541,7 +4697,7 @@ def main( print(f"\n💾 Sample trajectory saved to: {sample_filename}") except Exception as e: print(f"\n⚠️ Failed to save sample: {e}") - + print("\n👋 Agent execution completed!") diff --git a/tools/__init__.py b/tools/__init__.py index 04eabd0235..b3f255033e 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -16,249 +16,222 @@ for the AI agent to access all capabilities. """ # Export all tools for easy importing -from .web_tools import ( - web_search_tool, - web_extract_tool, - web_crawl_tool, - check_firecrawl_api_key -) - -# Primary terminal tool (mini-swe-agent backend: local/docker/singularity/modal/daytona) -from .terminal_tool import ( - terminal_tool, - check_terminal_requirements, - cleanup_vm, - cleanup_all_environments, - get_active_environments_info, - register_task_env_overrides, - clear_task_env_overrides, - TERMINAL_TOOL_DESCRIPTION -) - -from .vision_tools import ( - vision_analyze_tool, - check_vision_requirements -) - -from .mixture_of_agents_tool import ( - mixture_of_agents_tool, - check_moa_requirements -) - -from .image_generation_tool import ( - image_generate_tool, - check_image_generation_requirements -) - -from .skills_tool import ( - skills_list, - skill_view, - check_skills_requirements, - SKILLS_TOOL_DESCRIPTION -) - -from .skill_manager_tool import ( - skill_manage, - check_skill_manage_requirements, - SKILL_MANAGE_SCHEMA -) - # Browser automation tools (agent-browser + Browserbase) from .browser_tool import ( - browser_navigate, - browser_snapshot, - browser_click, - browser_type, - browser_scroll, + BROWSER_TOOL_SCHEMAS, browser_back, - browser_press, + browser_click, browser_close, browser_get_images, + browser_navigate, + browser_press, + browser_scroll, + browser_snapshot, + browser_type, browser_vision, - cleanup_browser, - cleanup_all_browsers, - get_active_browser_sessions, check_browser_requirements, - BROWSER_TOOL_SCHEMAS -) - -# Cronjob management tools (CLI-only, hermes-cli toolset) -from .cronjob_tools import ( - schedule_cronjob, - list_cronjobs, - remove_cronjob, - check_cronjob_requirements, - get_cronjob_tool_definitions, - SCHEDULE_CRONJOB_SCHEMA, - LIST_CRONJOBS_SCHEMA, - REMOVE_CRONJOB_SCHEMA -) - -# RL Training tools (Tinker-Atropos) -from .rl_training_tool import ( - rl_list_environments, - rl_select_environment, - rl_get_current_config, - rl_edit_config, - rl_start_training, - rl_check_status, - rl_stop_training, - rl_get_results, - rl_list_runs, - rl_test_inference, - check_rl_api_keys, - get_missing_keys, -) - -# File manipulation tools (read, write, patch, search) -from .file_tools import ( - read_file_tool, - write_file_tool, - patch_tool, - search_tool, - get_file_tools, - clear_file_ops_cache, -) - -# Text-to-speech tools (Edge TTS / ElevenLabs / OpenAI) -from .tts_tool import ( - text_to_speech_tool, - check_tts_requirements, -) - -# Planning & task management tool -from .todo_tool import ( - todo_tool, - check_todo_requirements, - TODO_SCHEMA, - TodoStore, + cleanup_all_browsers, + cleanup_browser, + get_active_browser_sessions, ) # Clarifying questions tool (interactive Q&A with the user) from .clarify_tool import ( - clarify_tool, - check_clarify_requirements, CLARIFY_SCHEMA, + check_clarify_requirements, + clarify_tool, ) # Code execution sandbox (programmatic tool calling) from .code_execution_tool import ( - execute_code, - check_sandbox_requirements, EXECUTE_CODE_SCHEMA, + check_sandbox_requirements, + execute_code, +) + +# Cronjob management tools (CLI-only, hermes-cli toolset) +from .cronjob_tools import ( + LIST_CRONJOBS_SCHEMA, + REMOVE_CRONJOB_SCHEMA, + SCHEDULE_CRONJOB_SCHEMA, + check_cronjob_requirements, + get_cronjob_tool_definitions, + list_cronjobs, + remove_cronjob, + schedule_cronjob, ) # Subagent delegation (spawn child agents with isolated context) from .delegate_tool import ( - delegate_task, - check_delegate_requirements, DELEGATE_TASK_SCHEMA, + check_delegate_requirements, + delegate_task, ) +# File manipulation tools (read, write, patch, search) +from .file_tools import ( + clear_file_ops_cache, + get_file_tools, + patch_tool, + read_file_tool, + search_tool, + write_file_tool, +) +from .image_generation_tool import check_image_generation_requirements, image_generate_tool +from .mixture_of_agents_tool import check_moa_requirements, mixture_of_agents_tool + +# RL Training tools (Tinker-Atropos) +from .rl_training_tool import ( + check_rl_api_keys, + get_missing_keys, + rl_check_status, + rl_edit_config, + rl_get_current_config, + rl_get_results, + rl_list_environments, + rl_list_runs, + rl_select_environment, + rl_start_training, + rl_stop_training, + rl_test_inference, +) +from .skill_manager_tool import SKILL_MANAGE_SCHEMA, check_skill_manage_requirements, skill_manage +from .skills_tool import SKILLS_TOOL_DESCRIPTION, check_skills_requirements, skill_view, skills_list + +# Primary terminal tool (mini-swe-agent backend: local/docker/singularity/modal/daytona) +from .terminal_tool import ( + TERMINAL_TOOL_DESCRIPTION, + check_terminal_requirements, + cleanup_all_environments, + cleanup_vm, + clear_task_env_overrides, + get_active_environments_info, + register_task_env_overrides, + terminal_tool, +) + +# Planning & task management tool +from .todo_tool import ( + TODO_SCHEMA, + TodoStore, + check_todo_requirements, + todo_tool, +) + +# Text-to-speech tools (Edge TTS / ElevenLabs / OpenAI) +from .tts_tool import ( + check_tts_requirements, + text_to_speech_tool, +) +from .vision_tools import check_vision_requirements, vision_analyze_tool +from .web_tools import check_firecrawl_api_key, web_crawl_tool, web_extract_tool, web_search_tool + + # File tools have no external requirements - they use the terminal backend def check_file_requirements(): """File tools only require terminal backend to be available.""" from .terminal_tool import check_terminal_requirements + return check_terminal_requirements() + __all__ = [ # Web tools - 'web_search_tool', - 'web_extract_tool', - 'web_crawl_tool', - 'check_firecrawl_api_key', + "web_search_tool", + "web_extract_tool", + "web_crawl_tool", + "check_firecrawl_api_key", # Terminal tools (mini-swe-agent backend) - 'terminal_tool', - 'check_terminal_requirements', - 'cleanup_vm', - 'cleanup_all_environments', - 'get_active_environments_info', - 'register_task_env_overrides', - 'clear_task_env_overrides', - 'TERMINAL_TOOL_DESCRIPTION', + "terminal_tool", + "check_terminal_requirements", + "cleanup_vm", + "cleanup_all_environments", + "get_active_environments_info", + "register_task_env_overrides", + "clear_task_env_overrides", + "TERMINAL_TOOL_DESCRIPTION", # Vision tools - 'vision_analyze_tool', - 'check_vision_requirements', + "vision_analyze_tool", + "check_vision_requirements", # MoA tools - 'mixture_of_agents_tool', - 'check_moa_requirements', + "mixture_of_agents_tool", + "check_moa_requirements", # Image generation tools - 'image_generate_tool', - 'check_image_generation_requirements', + "image_generate_tool", + "check_image_generation_requirements", # Skills tools - 'skills_list', - 'skill_view', - 'check_skills_requirements', - 'SKILLS_TOOL_DESCRIPTION', + "skills_list", + "skill_view", + "check_skills_requirements", + "SKILLS_TOOL_DESCRIPTION", # Skill management - 'skill_manage', - 'check_skill_manage_requirements', - 'SKILL_MANAGE_SCHEMA', + "skill_manage", + "check_skill_manage_requirements", + "SKILL_MANAGE_SCHEMA", # Browser automation tools - 'browser_navigate', - 'browser_snapshot', - 'browser_click', - 'browser_type', - 'browser_scroll', - 'browser_back', - 'browser_press', - 'browser_close', - 'browser_get_images', - 'browser_vision', - 'cleanup_browser', - 'cleanup_all_browsers', - 'get_active_browser_sessions', - 'check_browser_requirements', - 'BROWSER_TOOL_SCHEMAS', + "browser_navigate", + "browser_snapshot", + "browser_click", + "browser_type", + "browser_scroll", + "browser_back", + "browser_press", + "browser_close", + "browser_get_images", + "browser_vision", + "cleanup_browser", + "cleanup_all_browsers", + "get_active_browser_sessions", + "check_browser_requirements", + "BROWSER_TOOL_SCHEMAS", # Cronjob management tools (CLI-only) - 'schedule_cronjob', - 'list_cronjobs', - 'remove_cronjob', - 'check_cronjob_requirements', - 'get_cronjob_tool_definitions', - 'SCHEDULE_CRONJOB_SCHEMA', - 'LIST_CRONJOBS_SCHEMA', - 'REMOVE_CRONJOB_SCHEMA', + "schedule_cronjob", + "list_cronjobs", + "remove_cronjob", + "check_cronjob_requirements", + "get_cronjob_tool_definitions", + "SCHEDULE_CRONJOB_SCHEMA", + "LIST_CRONJOBS_SCHEMA", + "REMOVE_CRONJOB_SCHEMA", # RL Training tools - 'rl_list_environments', - 'rl_select_environment', - 'rl_get_current_config', - 'rl_edit_config', - 'rl_start_training', - 'rl_check_status', - 'rl_stop_training', - 'rl_get_results', - 'rl_list_runs', - 'rl_test_inference', - 'check_rl_api_keys', - 'get_missing_keys', + "rl_list_environments", + "rl_select_environment", + "rl_get_current_config", + "rl_edit_config", + "rl_start_training", + "rl_check_status", + "rl_stop_training", + "rl_get_results", + "rl_list_runs", + "rl_test_inference", + "check_rl_api_keys", + "get_missing_keys", # File manipulation tools - 'read_file_tool', - 'write_file_tool', - 'patch_tool', - 'search_tool', - 'get_file_tools', - 'clear_file_ops_cache', - 'check_file_requirements', + "read_file_tool", + "write_file_tool", + "patch_tool", + "search_tool", + "get_file_tools", + "clear_file_ops_cache", + "check_file_requirements", # Text-to-speech tools - 'text_to_speech_tool', - 'check_tts_requirements', + "text_to_speech_tool", + "check_tts_requirements", # Planning & task management tool - 'todo_tool', - 'check_todo_requirements', - 'TODO_SCHEMA', - 'TodoStore', + "todo_tool", + "check_todo_requirements", + "TODO_SCHEMA", + "TodoStore", # Clarifying questions tool - 'clarify_tool', - 'check_clarify_requirements', - 'CLARIFY_SCHEMA', + "clarify_tool", + "check_clarify_requirements", + "CLARIFY_SCHEMA", # Code execution sandbox - 'execute_code', - 'check_sandbox_requirements', - 'EXECUTE_CODE_SCHEMA', + "execute_code", + "check_sandbox_requirements", + "EXECUTE_CODE_SCHEMA", # Subagent delegation - 'delegate_task', - 'check_delegate_requirements', - 'DELEGATE_TASK_SCHEMA', + "delegate_task", + "check_delegate_requirements", + "DELEGATE_TASK_SCHEMA", ] - diff --git a/tools/approval.py b/tools/approval.py index cdf19e4435..f2de5f651e 100644 --- a/tools/approval.py +++ b/tools/approval.py @@ -12,7 +12,6 @@ import os import re import sys import threading -from typing import Optional logger = logging.getLogger(__name__) @@ -21,32 +20,32 @@ logger = logging.getLogger(__name__) # ========================================================================= DANGEROUS_PATTERNS = [ - (r'\brm\s+(-[^\s]*\s+)*/', "delete in root path"), - (r'\brm\s+-[^\s]*r', "recursive delete"), - (r'\brm\s+--recursive\b', "recursive delete (long flag)"), - (r'\bchmod\s+(-[^\s]*\s+)*777\b', "world-writable permissions"), - (r'\bchmod\s+--recursive\b.*777', "recursive world-writable (long flag)"), - (r'\bchown\s+(-[^\s]*)?R\s+root', "recursive chown to root"), - (r'\bchown\s+--recursive\b.*root', "recursive chown to root (long flag)"), - (r'\bmkfs\b', "format filesystem"), - (r'\bdd\s+.*if=', "disk copy"), - (r'>\s*/dev/sd', "write to block device"), - (r'\bDROP\s+(TABLE|DATABASE)\b', "SQL DROP"), - (r'\bDELETE\s+FROM\b(?!.*\bWHERE\b)', "SQL DELETE without WHERE"), - (r'\bTRUNCATE\s+(TABLE)?\s*\w', "SQL TRUNCATE"), - (r'>\s*/etc/', "overwrite system config"), - (r'\bsystemctl\s+(stop|disable|mask)\b', "stop/disable system service"), - (r'\bkill\s+-9\s+-1\b', "kill all processes"), - (r'\bpkill\s+-9\b', "force kill processes"), - (r':()\s*{\s*:\s*\|\s*:&\s*}\s*;:', "fork bomb"), - (r'\b(bash|sh|zsh)\s+-c\s+', "shell command via -c flag"), - (r'\b(python[23]?|perl|ruby|node)\s+-[ec]\s+', "script execution via -e/-c flag"), - (r'\b(curl|wget)\b.*\|\s*(ba)?sh\b', "pipe remote content to shell"), - (r'\b(bash|sh|zsh|ksh)\s+<\s*\s*/dev/sd", "write to block device"), + (r"\bDROP\s+(TABLE|DATABASE)\b", "SQL DROP"), + (r"\bDELETE\s+FROM\b(?!.*\bWHERE\b)", "SQL DELETE without WHERE"), + (r"\bTRUNCATE\s+(TABLE)?\s*\w", "SQL TRUNCATE"), + (r">\s*/etc/", "overwrite system config"), + (r"\bsystemctl\s+(stop|disable|mask)\b", "stop/disable system service"), + (r"\bkill\s+-9\s+-1\b", "kill all processes"), + (r"\bpkill\s+-9\b", "force kill processes"), + (r":()\s*{\s*:\s*\|\s*:&\s*}\s*;:", "fork bomb"), + (r"\b(bash|sh|zsh)\s+-c\s+", "shell command via -c flag"), + (r"\b(python[23]?|perl|ruby|node)\s+-[ec]\s+", "script execution via -e/-c flag"), + (r"\b(curl|wget)\b.*\|\s*(ba)?sh\b", "pipe remote content to shell"), + (r"\b(bash|sh|zsh|ksh)\s+<\s* tuple: """Check if a command matches any dangerous patterns. @@ -63,7 +63,7 @@ def detect_dangerous_command(command: str) -> tuple: command_lower = command.lower() for pattern, description in DANGEROUS_PATTERNS: if re.search(pattern, command_lower, re.IGNORECASE | re.DOTALL): - pattern_key = pattern.split(r'\b')[1] if r'\b' in pattern else pattern[:20] + pattern_key = pattern.split(r"\b")[1] if r"\b" in pattern else pattern[:20] return (True, pattern_key, description) return (False, None, None) @@ -84,7 +84,7 @@ def submit_pending(session_key: str, approval: dict): _pending[session_key] = approval -def pop_pending(session_key: str) -> Optional[dict]: +def pop_pending(session_key: str) -> dict | None: """Retrieve and remove a pending approval for a session.""" with _lock: return _pending.pop(session_key, None) @@ -133,6 +133,7 @@ def clear_session(session_key: str): # Config persistence for permanent allowlist # ========================================================================= + def load_permanent_allowlist() -> set: """Load permanently allowed command patterns from config. @@ -141,6 +142,7 @@ def load_permanent_allowlist() -> set: """ try: from hermes_cli.config import load_config + config = load_config() patterns = set(config.get("command_allowlist", []) or []) if patterns: @@ -154,6 +156,7 @@ def save_permanent_allowlist(patterns: set): """Save permanently allowed command patterns to config.""" try: from hermes_cli.config import load_config, save_config + config = load_config() config["command_allowlist"] = list(patterns) save_config(config) @@ -165,9 +168,8 @@ def save_permanent_allowlist(patterns: set): # Approval prompting + orchestration # ========================================================================= -def prompt_dangerous_approval(command: str, description: str, - timeout_seconds: int = 60, - approval_callback=None) -> str: + +def prompt_dangerous_approval(command: str, description: str, timeout_seconds: int = 60, approval_callback=None) -> str: """Prompt the user to approve a dangerous command (CLI only). Args: @@ -188,7 +190,7 @@ def prompt_dangerous_approval(command: str, description: str, print(f" ⚠️ DANGEROUS COMMAND: {description}") print(f" {command[:80]}{'...' if len(command) > 80 else ''}") print() - print(f" [o]nce | [s]ession | [a]lways | [d]eny") + print(" [o]nce | [s]ession | [a]lways | [d]eny") print() sys.stdout.flush() @@ -209,13 +211,13 @@ def prompt_dangerous_approval(command: str, description: str, return "deny" choice = result["choice"] - if choice in ('o', 'once'): + if choice in ("o", "once"): print(" ✓ Allowed once") return "once" - elif choice in ('s', 'session'): + elif choice in ("s", "session"): print(" ✓ Allowed for this session") return "session" - elif choice in ('a', 'always'): + elif choice in ("a", "always"): print(" ✓ Added to permanent allowlist") return "always" else: @@ -232,8 +234,7 @@ def prompt_dangerous_approval(command: str, description: str, sys.stdout.flush() -def check_dangerous_command(command: str, env_type: str, - approval_callback=None) -> dict: +def check_dangerous_command(command: str, env_type: str, approval_callback=None) -> dict: """Check if a command is dangerous and handle approval. This is the main entry point called by terminal_tool before executing @@ -265,11 +266,14 @@ def check_dangerous_command(command: str, env_type: str, return {"approved": True, "message": None} if is_gateway or os.getenv("HERMES_EXEC_ASK"): - submit_pending(session_key, { - "command": command, - "pattern_key": pattern_key, - "description": description, - }) + submit_pending( + session_key, + { + "command": command, + "pattern_key": pattern_key, + "description": description, + }, + ) return { "approved": False, "pattern_key": pattern_key, @@ -279,8 +283,7 @@ def check_dangerous_command(command: str, env_type: str, "message": f"⚠️ This command is potentially dangerous ({description}). Asking the user for approval...", } - choice = prompt_dangerous_approval(command, description, - approval_callback=approval_callback) + choice = prompt_dangerous_approval(command, description, approval_callback=approval_callback) if choice == "deny": return { diff --git a/tools/browser_tool.py b/tools/browser_tool.py index 480093eaa5..dab3a7a43d 100644 --- a/tools/browser_tool.py +++ b/tools/browser_tool.py @@ -38,13 +38,13 @@ Environment Variables: Usage: from tools.browser_tool import browser_navigate, browser_snapshot, browser_click - + # Navigate to a page result = browser_navigate("https://example.com", task_id="task_123") - + # Get page snapshot snapshot = browser_snapshot(task_id="task_123") - + # Click an element browser_click("@e5", task_id="task_123") """ @@ -53,17 +53,19 @@ import atexit import json import logging import os +import shutil import signal import subprocess -import shutil import sys import tempfile import threading import time -import requests -from typing import Dict, Any, Optional, List from pathlib import Path -from agent.auxiliary_client import get_vision_auxiliary_client, get_text_auxiliary_client +from typing import Any + +import requests + +from agent.auxiliary_client import get_text_auxiliary_client, get_vision_auxiliary_client logger = logging.getLogger(__name__) @@ -102,16 +104,14 @@ EXTRACTION_MODEL = _DEFAULT_TEXT_MODEL or _DEFAULT_VISION_MODEL def _get_vision_model() -> str: """Model for browser_vision (screenshot analysis — multimodal).""" - return (os.getenv("AUXILIARY_VISION_MODEL", "").strip() - or _DEFAULT_VISION_MODEL - or "google/gemini-3-flash-preview") + return os.getenv("AUXILIARY_VISION_MODEL", "").strip() or _DEFAULT_VISION_MODEL or "google/gemini-3-flash-preview" def _get_extraction_model() -> str: """Model for page snapshot text summarization — same as web_extract.""" - return (os.getenv("AUXILIARY_WEB_EXTRACT_MODEL", "").strip() - or _DEFAULT_TEXT_MODEL - or "google/gemini-3-flash-preview") + return ( + os.getenv("AUXILIARY_WEB_EXTRACT_MODEL", "").strip() or _DEFAULT_TEXT_MODEL or "google/gemini-3-flash-preview" + ) def _is_local_mode() -> bool: @@ -143,7 +143,7 @@ def _socket_safe_tmpdir() -> str: # Track active sessions per task # Stores: session_name (always), bb_session_id + cdp_url (cloud mode only) -_active_sessions: Dict[str, Dict[str, str]] = {} # task_id -> {session_name, ...} +_active_sessions: dict[str, dict[str, str]] = {} # task_id -> {session_name, ...} _recording_sessions: set = set() # task_ids with active recordings # Flag to track if cleanup has been done @@ -159,7 +159,7 @@ _cleanup_done = False BROWSER_SESSION_INACTIVITY_TIMEOUT = int(os.environ.get("BROWSER_INACTIVITY_TIMEOUT", "300")) # Track last activity time per session -_session_last_activity: Dict[str, float] = {} +_session_last_activity: dict[str, float] = {} # Background cleanup thread state _cleanup_thread = None @@ -178,12 +178,12 @@ def _emergency_cleanup_all_sessions(): if _cleanup_done: return _cleanup_done = True - + if not _active_sessions: return - + logger.info("Emergency cleanup: closing %s active session(s)...", len(_active_sessions)) - + try: if _is_local_mode(): # Local mode: just close agent-browser sessions via CLI @@ -192,14 +192,13 @@ def _emergency_cleanup_all_sessions(): if session_name: try: browser_cmd = _find_agent_browser() - task_socket_dir = os.path.join( - _socket_safe_tmpdir(), - f"agent-browser-{session_name}" - ) + task_socket_dir = os.path.join(_socket_safe_tmpdir(), f"agent-browser-{session_name}") env = {**os.environ, "AGENT_BROWSER_SOCKET_DIR": task_socket_dir} subprocess.run( browser_cmd.split() + ["--session", session_name, "--json", "close"], - capture_output=True, timeout=5, env=env, + capture_output=True, + timeout=5, + env=env, ) logger.info("Closed local session %s", session_name) except Exception as e: @@ -219,15 +218,9 @@ def _emergency_cleanup_all_sessions(): try: response = requests.post( f"https://api.browserbase.com/v1/sessions/{bb_session_id}", - headers={ - "X-BB-API-Key": api_key, - "Content-Type": "application/json" - }, - json={ - "projectId": project_id, - "status": "REQUEST_RELEASE" - }, - timeout=5 # Short timeout for cleanup + headers={"X-BB-API-Key": api_key, "Content-Type": "application/json"}, + json={"projectId": project_id, "status": "REQUEST_RELEASE"}, + timeout=5, # Short timeout for cleanup ) if response.status_code in (200, 201, 204): logger.info("Closed session %s", bb_session_id) @@ -235,7 +228,7 @@ def _emergency_cleanup_all_sessions(): logger.warning("Failed to close session %s: HTTP %s", bb_session_id, response.status_code) except Exception as e: logger.error("Error closing session %s: %s", bb_session_id, e) - + _active_sessions.clear() except Exception as e: logger.error("Emergency cleanup error: %s", e) @@ -264,22 +257,23 @@ except (OSError, AttributeError): # Inactivity Cleanup Functions # ============================================================================= + def _cleanup_inactive_browser_sessions(): """ Clean up browser sessions that have been inactive for longer than the timeout. - + This function is called periodically by the background cleanup thread to automatically close sessions that haven't been used recently, preventing orphaned sessions (local or Browserbase) from accumulating. """ current_time = time.time() sessions_to_cleanup = [] - + with _cleanup_lock: for task_id, last_time in list(_session_last_activity.items()): if current_time - last_time > BROWSER_SESSION_INACTIVITY_TIMEOUT: sessions_to_cleanup.append(task_id) - + for task_id in sessions_to_cleanup: try: elapsed = int(current_time - _session_last_activity.get(task_id, current_time)) @@ -295,18 +289,18 @@ def _cleanup_inactive_browser_sessions(): def _browser_cleanup_thread_worker(): """ Background thread that periodically cleans up inactive browser sessions. - + Runs every 30 seconds and checks for sessions that haven't been used within the BROWSER_SESSION_INACTIVITY_TIMEOUT period. """ global _cleanup_running - + while _cleanup_running: try: _cleanup_inactive_browser_sessions() except Exception as e: logger.warning("Cleanup thread error: %s", e) - + # Sleep in 1-second intervals so we can stop quickly if needed for _ in range(30): if not _cleanup_running: @@ -317,14 +311,12 @@ def _browser_cleanup_thread_worker(): def _start_browser_cleanup_thread(): """Start the background cleanup thread if not already running.""" global _cleanup_thread, _cleanup_running - + with _cleanup_lock: if _cleanup_thread is None or not _cleanup_thread.is_alive(): _cleanup_running = True _cleanup_thread = threading.Thread( - target=_browser_cleanup_thread_worker, - daemon=True, - name="browser-cleanup" + target=_browser_cleanup_thread_worker, daemon=True, name="browser-cleanup" ) _cleanup_thread.start() logger.info("Started inactivity cleanup thread (timeout: %ss)", BROWSER_SESSION_INACTIVITY_TIMEOUT) @@ -359,13 +351,10 @@ BROWSER_TOOL_SCHEMAS = [ "parameters": { "type": "object", "properties": { - "url": { - "type": "string", - "description": "The URL to navigate to (e.g., 'https://example.com')" - } + "url": {"type": "string", "description": "The URL to navigate to (e.g., 'https://example.com')"} }, - "required": ["url"] - } + "required": ["url"], + }, }, { "name": "browser_snapshot", @@ -376,11 +365,11 @@ BROWSER_TOOL_SCHEMAS = [ "full": { "type": "boolean", "description": "If true, returns complete page content. If false (default), returns compact view with interactive elements only.", - "default": False + "default": False, } }, - "required": [] - } + "required": [], + }, }, { "name": "browser_click", @@ -390,11 +379,11 @@ BROWSER_TOOL_SCHEMAS = [ "properties": { "ref": { "type": "string", - "description": "The element reference from the snapshot (e.g., '@e5', '@e12')" + "description": "The element reference from the snapshot (e.g., '@e5', '@e12')", } }, - "required": ["ref"] - } + "required": ["ref"], + }, }, { "name": "browser_type", @@ -402,17 +391,11 @@ BROWSER_TOOL_SCHEMAS = [ "parameters": { "type": "object", "properties": { - "ref": { - "type": "string", - "description": "The element reference from the snapshot (e.g., '@e3')" - }, - "text": { - "type": "string", - "description": "The text to type into the field" - } + "ref": {"type": "string", "description": "The element reference from the snapshot (e.g., '@e3')"}, + "text": {"type": "string", "description": "The text to type into the field"}, }, - "required": ["ref", "text"] - } + "required": ["ref", "text"], + }, }, { "name": "browser_scroll", @@ -420,23 +403,15 @@ BROWSER_TOOL_SCHEMAS = [ "parameters": { "type": "object", "properties": { - "direction": { - "type": "string", - "enum": ["up", "down"], - "description": "Direction to scroll" - } + "direction": {"type": "string", "enum": ["up", "down"], "description": "Direction to scroll"} }, - "required": ["direction"] - } + "required": ["direction"], + }, }, { "name": "browser_back", "description": "Navigate back to the previous page in browser history. Requires browser_navigate to be called first.", - "parameters": { - "type": "object", - "properties": {}, - "required": [] - } + "parameters": {"type": "object", "properties": {}, "required": []}, }, { "name": "browser_press", @@ -444,31 +419,20 @@ BROWSER_TOOL_SCHEMAS = [ "parameters": { "type": "object", "properties": { - "key": { - "type": "string", - "description": "Key to press (e.g., 'Enter', 'Tab', 'Escape', 'ArrowDown')" - } + "key": {"type": "string", "description": "Key to press (e.g., 'Enter', 'Tab', 'Escape', 'ArrowDown')"} }, - "required": ["key"] - } + "required": ["key"], + }, }, { "name": "browser_close", "description": "Close the browser session and release resources. Call this when done with browser tasks to free up Browserbase session quota.", - "parameters": { - "type": "object", - "properties": {}, - "required": [] - } + "parameters": {"type": "object", "properties": {}, "required": []}, }, { "name": "browser_get_images", "description": "Get a list of all images on the current page with their URLs and alt text. Useful for finding images to analyze with the vision tool. Requires browser_navigate to be called first.", - "parameters": { - "type": "object", - "properties": {}, - "required": [] - } + "parameters": {"type": "object", "properties": {}, "required": []}, }, { "name": "browser_vision", @@ -478,16 +442,16 @@ BROWSER_TOOL_SCHEMAS = [ "properties": { "question": { "type": "string", - "description": "What you want to know about the page visually. Be specific about what you're looking for." + "description": "What you want to know about the page visually. Be specific about what you're looking for.", }, "annotate": { "type": "boolean", "default": False, - "description": "If true, overlay numbered [N] labels on interactive elements. Each [N] maps to ref @eN for subsequent browser commands. Useful for QA and spatial reasoning about page layout." - } + "description": "If true, overlay numbered [N] labels on interactive elements. Each [N] maps to ref @eN for subsequent browser commands. Useful for QA and spatial reasoning about page layout.", + }, }, - "required": ["question"] - } + "required": ["question"], + }, }, { "name": "browser_console", @@ -498,11 +462,11 @@ BROWSER_TOOL_SCHEMAS = [ "clear": { "type": "boolean", "default": False, - "description": "If true, clear the message buffers after reading" + "description": "If true, clear the message buffers after reading", } }, - "required": [] - } + "required": [], + }, }, ] @@ -511,31 +475,31 @@ BROWSER_TOOL_SCHEMAS = [ # Utility Functions # ============================================================================ -def _create_browserbase_session(task_id: str) -> Dict[str, str]: + +def _create_browserbase_session(task_id: str) -> dict[str, str]: """ Create a Browserbase session with stealth features. - + Browserbase Stealth Modes: - Basic Stealth: ALWAYS enabled automatically. Generates random fingerprints, viewports, and solves visual CAPTCHAs. No configuration needed. - Advanced Stealth: Uses custom Chromium build for better bot detection avoidance. Requires Scale Plan. Enable via BROWSERBASE_ADVANCED_STEALTH=true. - + Proxies are enabled by default to route traffic through residential IPs, which significantly improves CAPTCHA solving rates. Can be disabled via BROWSERBASE_PROXIES=false if needed. - + Args: task_id: Unique identifier for the task - + Returns: Dict with session_name, bb_session_id, cdp_url, and feature flags """ import uuid - import sys - + config = _get_browserbase_config() - + # Check for optional settings from environment # Proxies: enabled by default for better CAPTCHA solving enable_proxies = os.environ.get("BROWSERBASE_PROXIES", "true").lower() != "false" @@ -545,7 +509,7 @@ def _create_browserbase_session(task_id: str) -> Dict[str, str]: enable_keep_alive = os.environ.get("BROWSERBASE_KEEP_ALIVE", "true").lower() != "false" # Custom session timeout in milliseconds (optional) - extends session beyond project default custom_timeout_ms = os.environ.get("BROWSERBASE_SESSION_TIMEOUT") - + # Track which features are actually enabled for logging/debugging features_enabled = { "basic_stealth": True, # Always on @@ -554,18 +518,18 @@ def _create_browserbase_session(task_id: str) -> Dict[str, str]: "keep_alive": False, "custom_timeout": False, } - + # Build session configuration # Note: Basic stealth mode is ALWAYS active - no configuration needed session_config = { "projectId": config["project_id"], } - + # Enable keepAlive for session reconnection (default: true, requires paid plan) # Allows reconnecting to the same session after network hiccups if enable_keep_alive: session_config["keepAlive"] = True - + # Add custom timeout if specified (in milliseconds) # This extends session duration beyond project's default timeout if custom_timeout_ms: @@ -575,19 +539,19 @@ def _create_browserbase_session(task_id: str) -> Dict[str, str]: session_config["timeout"] = timeout_val except ValueError: logger.warning("Invalid BROWSERBASE_SESSION_TIMEOUT value: %s", custom_timeout_ms) - + # Enable proxies for better CAPTCHA solving (default: true) # Routes traffic through residential IPs for more reliable access if enable_proxies: session_config["proxies"] = True - + # Add advanced stealth if enabled (requires Scale Plan) # Uses custom Chromium build to avoid bot detection altogether if enable_advanced_stealth: session_config["browserSettings"] = { "advancedStealth": True, } - + # Create session via Browserbase API response = requests.post( "https://api.browserbase.com/v1/sessions", @@ -596,21 +560,23 @@ def _create_browserbase_session(task_id: str) -> Dict[str, str]: "X-BB-API-Key": config["api_key"], }, json=session_config, - timeout=30 + timeout=30, ) - + # Track if we fell back from paid features proxies_fallback = False keepalive_fallback = False - + # Handle 402 Payment Required - likely paid features not available # Try to identify which feature caused the issue and retry without it if response.status_code == 402: # First try without keepAlive (most likely culprit for paid plan requirement) if enable_keep_alive: keepalive_fallback = True - logger.warning("keepAlive may require paid plan (402), retrying without it. " - "Sessions may timeout during long operations.") + logger.warning( + "keepAlive may require paid plan (402), retrying without it. " + "Sessions may timeout during long operations." + ) session_config.pop("keepAlive", None) response = requests.post( "https://api.browserbase.com/v1/sessions", @@ -619,14 +585,13 @@ def _create_browserbase_session(task_id: str) -> Dict[str, str]: "X-BB-API-Key": config["api_key"], }, json=session_config, - timeout=30 + timeout=30, ) - + # If still 402, try without proxies too if response.status_code == 402 and enable_proxies: proxies_fallback = True - logger.warning("Proxies unavailable (402), retrying without proxies. " - "Bot detection may be less effective.") + logger.warning("Proxies unavailable (402), retrying without proxies. Bot detection may be less effective.") session_config.pop("proxies", None) response = requests.post( "https://api.browserbase.com/v1/sessions", @@ -635,15 +600,15 @@ def _create_browserbase_session(task_id: str) -> Dict[str, str]: "X-BB-API-Key": config["api_key"], }, json=session_config, - timeout=30 + timeout=30, ) - + if not response.ok: raise RuntimeError(f"Failed to create Browserbase session: {response.status_code} {response.text}") - + session_data = response.json() session_name = f"hermes_{task_id}_{uuid.uuid4().hex[:8]}" - + # Update features based on what actually succeeded if enable_proxies and not proxies_fallback: features_enabled["proxies"] = True @@ -653,11 +618,11 @@ def _create_browserbase_session(task_id: str) -> Dict[str, str]: features_enabled["keep_alive"] = True if custom_timeout_ms and "timeout" in session_config: features_enabled["custom_timeout"] = True - + # Log session info for debugging feature_str = ", ".join(k for k, v in features_enabled.items() if v) logger.info("Created session %s with features: %s", session_name, feature_str) - + return { "session_name": session_name, "bb_session_id": session_data["id"], @@ -666,71 +631,72 @@ def _create_browserbase_session(task_id: str) -> Dict[str, str]: } -def _create_local_session(task_id: str) -> Dict[str, str]: +def _create_local_session(task_id: str) -> dict[str, str]: """Create a lightweight local browser session (no cloud API call). Returns the same dict shape as ``_create_browserbase_session`` so the rest of the code can treat both modes uniformly. """ import uuid + session_name = f"hermes_{task_id}_{uuid.uuid4().hex[:8]}" logger.info("Created local browser session %s", session_name) return { "session_name": session_name, - "bb_session_id": None, # Not applicable in local mode - "cdp_url": None, # Not applicable in local mode + "bb_session_id": None, # Not applicable in local mode + "cdp_url": None, # Not applicable in local mode "features": {"local": True}, } -def _get_session_info(task_id: Optional[str] = None) -> Dict[str, str]: +def _get_session_info(task_id: str | None = None) -> dict[str, str]: """ Get or create session info for the given task. - + In cloud mode, creates a Browserbase session with proxies enabled. In local mode, generates a session name for agent-browser --session. Also starts the inactivity cleanup thread and updates activity tracking. Thread-safe: multiple subagents can call this concurrently. - + Args: task_id: Unique identifier for the task - + Returns: Dict with session_name (always), bb_session_id + cdp_url (cloud only) """ if task_id is None: task_id = "default" - + # Start the cleanup thread if not running (handles inactivity timeouts) _start_browser_cleanup_thread() - + # Update activity timestamp for this session _update_session_activity(task_id) - + with _cleanup_lock: # Check if we already have a session for this task if task_id in _active_sessions: return _active_sessions[task_id] - + # Create session outside the lock (network call in cloud mode) if _is_local_mode(): session_info = _create_local_session(task_id) else: session_info = _create_browserbase_session(task_id) - + with _cleanup_lock: _active_sessions[task_id] = session_info - + return session_info -def _get_session_name(task_id: Optional[str] = None) -> str: +def _get_session_name(task_id: str | None = None) -> str: """ Get the session name for agent-browser CLI. - + Args: task_id: Unique identifier for the task - + Returns: Session name for agent-browser """ @@ -738,40 +704,37 @@ def _get_session_name(task_id: Optional[str] = None) -> str: return session_info["session_name"] -def _get_browserbase_config() -> Dict[str, str]: +def _get_browserbase_config() -> dict[str, str]: """ Get Browserbase configuration from environment. - + Returns: Dict with api_key and project_id - + Raises: ValueError: If required env vars are not set """ api_key = os.environ.get("BROWSERBASE_API_KEY") project_id = os.environ.get("BROWSERBASE_PROJECT_ID") - + if not api_key or not project_id: raise ValueError( "BROWSERBASE_API_KEY and BROWSERBASE_PROJECT_ID environment variables are required. " "Get your credentials at https://browserbase.com" ) - - return { - "api_key": api_key, - "project_id": project_id - } + + return {"api_key": api_key, "project_id": project_id} def _find_agent_browser() -> str: """ Find the agent-browser CLI executable. - + Checks in order: PATH, local node_modules/.bin/, npx fallback. - + Returns: Path to agent-browser executable - + Raises: FileNotFoundError: If agent-browser is not installed """ @@ -780,18 +743,18 @@ def _find_agent_browser() -> str: which_result = shutil.which("agent-browser") if which_result: return which_result - + # Check local node_modules/.bin/ (npm install in repo root) repo_root = Path(__file__).parent.parent local_bin = repo_root / "node_modules" / ".bin" / "agent-browser" if local_bin.exists(): return str(local_bin) - + # Check common npx locations npx_path = shutil.which("npx") if npx_path: return "npx agent-browser" - + raise FileNotFoundError( "agent-browser CLI not found. Install it with: npm install -g agent-browser\n" "Or run 'npm install' in the repo root to install locally.\n" @@ -800,33 +763,31 @@ def _find_agent_browser() -> str: def _run_browser_command( - task_id: str, - command: str, - args: List[str] = None, - timeout: int = DEFAULT_COMMAND_TIMEOUT -) -> Dict[str, Any]: + task_id: str, command: str, args: list[str] = None, timeout: int = DEFAULT_COMMAND_TIMEOUT +) -> dict[str, Any]: """ Run an agent-browser CLI command using our pre-created Browserbase session. - + Args: task_id: Task identifier to get the right session command: The command to run (e.g., "open", "click") args: Additional arguments for the command timeout: Command timeout in seconds - + Returns: Parsed JSON response from agent-browser """ args = args or [] - + # Build the command try: browser_cmd = _find_agent_browser() except FileNotFoundError as e: logger.warning("agent-browser CLI not found: %s", e) return {"success": False, "error": str(e)} - + from tools.interrupt import is_interrupted + if is_interrupted(): return {"success": False, "error": "Interrupted"} @@ -836,7 +797,7 @@ def _run_browser_command( except Exception as e: logger.warning("Failed to create browser session for task=%s: %s", task_id, e) return {"success": False, "error": f"Failed to create browser session: {str(e)}"} - + # Build the command with the appropriate backend flag. # Cloud mode: --cdp connects to Browserbase. # Local mode: --session launches a local headless Chromium. @@ -850,30 +811,25 @@ def _run_browser_command( # Local mode — launch a headless Chromium instance backend_args = ["--session", session_info["session_name"]] - cmd_parts = browser_cmd.split() + backend_args + [ - "--json", - command - ] + args - + cmd_parts = browser_cmd.split() + backend_args + ["--json", command] + args + try: # Give each task its own socket directory to prevent concurrency conflicts. # Without this, parallel workers fight over the same default socket path, # causing "Failed to create socket directory: Permission denied" errors. - task_socket_dir = os.path.join( - _socket_safe_tmpdir(), - f"agent-browser-{session_info['session_name']}" - ) + task_socket_dir = os.path.join(_socket_safe_tmpdir(), f"agent-browser-{session_info['session_name']}") os.makedirs(task_socket_dir, mode=0o700, exist_ok=True) - logger.debug("browser cmd=%s task=%s socket_dir=%s (%d chars)", - command, task_id, task_socket_dir, len(task_socket_dir)) - + logger.debug( + "browser cmd=%s task=%s socket_dir=%s (%d chars)", command, task_id, task_socket_dir, len(task_socket_dir) + ) + browser_env = {**os.environ} # Ensure PATH includes standard dirs (systemd services may have minimal PATH) _SANE_PATH = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" if "/usr/bin" not in browser_env.get("PATH", "").split(":"): browser_env["PATH"] = f"{browser_env.get('PATH', '')}:{_SANE_PATH}" browser_env["AGENT_BROWSER_SOCKET_DIR"] = task_socket_dir - + result = subprocess.run( cmd_parts, capture_output=True, @@ -881,18 +837,20 @@ def _run_browser_command( timeout=timeout, env=browser_env, ) - + # Log stderr for diagnostics — use warning level on failure so it's visible if result.stderr and result.stderr.strip(): level = logging.WARNING if result.returncode != 0 else logging.DEBUG logger.log(level, "browser '%s' stderr: %s", command, result.stderr.strip()[:500]) - + # Log empty output as warning — common sign of broken agent-browser if not result.stdout.strip() and result.returncode == 0: - logger.warning("browser '%s' returned empty stdout with rc=0. " - "cmd=%s stderr=%s", - command, " ".join(cmd_parts[:4]) + "...", - (result.stderr or "")[:200]) + logger.warning( + "browser '%s' returned empty stdout with rc=0. cmd=%s stderr=%s", + command, + " ".join(cmd_parts[:4]) + "...", + (result.stderr or "")[:200], + ) # Parse JSON output if result.stdout.strip(): @@ -902,41 +860,40 @@ def _run_browser_command( if command == "snapshot" and parsed.get("success"): snap_data = parsed.get("data", {}) if not snap_data.get("snapshot") and not snap_data.get("refs"): - logger.warning("snapshot returned empty content. " - "Possible stale daemon or CDP connection issue. " - "returncode=%s", result.returncode) + logger.warning( + "snapshot returned empty content. " + "Possible stale daemon or CDP connection issue. " + "returncode=%s", + result.returncode, + ) return parsed except json.JSONDecodeError: # Non-JSON output indicates agent-browser crash or version mismatch raw = result.stdout.strip()[:500] - logger.warning("browser '%s' returned non-JSON output (rc=%s): %s", - command, result.returncode, raw[:200]) - return { - "success": True, - "data": {"raw": raw} - } - + logger.warning( + "browser '%s' returned non-JSON output (rc=%s): %s", command, result.returncode, raw[:200] + ) + return {"success": True, "data": {"raw": raw}} + # Check for errors if result.returncode != 0: error_msg = result.stderr.strip() if result.stderr else f"Command failed with code {result.returncode}" logger.warning("browser '%s' failed (rc=%s): %s", command, result.returncode, error_msg[:300]) return {"success": False, "error": error_msg} - + return {"success": True, "data": {}} - + except subprocess.TimeoutExpired: - logger.warning("browser '%s' timed out after %ds (task=%s, socket_dir=%s)", - command, timeout, task_id, task_socket_dir) + logger.warning( + "browser '%s' timed out after %ds (task=%s, socket_dir=%s)", command, timeout, task_id, task_socket_dir + ) return {"success": False, "error": f"Command timed out after {timeout} seconds"} except Exception as e: logger.warning("browser '%s' exception: %s", command, e, exc_info=True) return {"success": False, "error": str(e)} -def _extract_relevant_content( - snapshot_text: str, - user_task: Optional[str] = None -) -> str: +def _extract_relevant_content(snapshot_text: str, user_task: str | None = None) -> str: """Use LLM to extract relevant content from a snapshot based on the user's task. Falls back to simple truncation when no auxiliary text model is configured. @@ -969,6 +926,7 @@ def _extract_relevant_content( try: from agent.auxiliary_client import auxiliary_max_tokens_param + response = _aux_text_client.chat.completions.create( model=_get_extraction_model(), messages=[{"role": "user", "content": extraction_prompt}], @@ -983,17 +941,17 @@ def _extract_relevant_content( def _truncate_snapshot(snapshot_text: str, max_chars: int = 8000) -> str: """ Simple truncation fallback for snapshots. - + Args: snapshot_text: The snapshot text to truncate max_chars: Maximum characters to keep - + Returns: Truncated text with indicator if truncated """ if len(snapshot_text) <= max_chars: return snapshot_text - + return snapshot_text[:max_chars] + "\n\n[... content truncated ...]" @@ -1001,52 +959,57 @@ def _truncate_snapshot(snapshot_text: str, max_chars: int = 8000) -> str: # Browser Tool Functions # ============================================================================ -def browser_navigate(url: str, task_id: Optional[str] = None) -> str: + +def browser_navigate(url: str, task_id: str | None = None) -> str: """ Navigate to a URL in the browser. - + Args: url: The URL to navigate to task_id: Task identifier for session isolation - + Returns: JSON string with navigation result (includes stealth features info on first nav) """ effective_task_id = task_id or "default" - + # Get session info to check if this is a new session # (will create one with features logged if not exists) session_info = _get_session_info(effective_task_id) is_first_nav = session_info.get("_first_nav", True) - + # Auto-start recording if configured and this is first navigation if is_first_nav: session_info["_first_nav"] = False _maybe_start_recording(effective_task_id) - + result = _run_browser_command(effective_task_id, "open", [url], timeout=60) - + if result.get("success"): data = result.get("data", {}) title = data.get("title", "") final_url = data.get("url", url) - - response = { - "success": True, - "url": final_url, - "title": title - } - + + response = {"success": True, "url": final_url, "title": title} + # Detect common "blocked" page patterns from title/url blocked_patterns = [ - "access denied", "access to this page has been denied", - "blocked", "bot detected", "verification required", - "please verify", "are you a robot", "captcha", - "cloudflare", "ddos protection", "checking your browser", - "just a moment", "attention required" + "access denied", + "access to this page has been denied", + "blocked", + "bot detected", + "verification required", + "please verify", + "are you a robot", + "captcha", + "cloudflare", + "ddos protection", + "checking your browser", + "just a moment", + "attention required", ] title_lower = title.lower() - + if any(pattern in title_lower for pattern in blocked_patterns): response["bot_detection_warning"] = ( f"Page title '{title}' suggests bot detection. The site may have blocked this request. " @@ -1054,7 +1017,7 @@ def browser_navigate(url: str, task_id: Optional[str] = None) -> str: "3) Enable advanced stealth (BROWSERBASE_ADVANCED_STEALTH=true, requires Scale plan), " "4) Some sites have very aggressive bot detection that may be unavoidable." ) - + # Include feature info on first navigation so model knows what's active if is_first_nav and "features" in session_info: features = session_info["features"] @@ -1065,233 +1028,197 @@ def browser_navigate(url: str, task_id: Optional[str] = None) -> str: "Consider upgrading Browserbase plan for proxy support." ) response["stealth_features"] = active_features - + return json.dumps(response, ensure_ascii=False) else: - return json.dumps({ - "success": False, - "error": result.get("error", "Navigation failed") - }, ensure_ascii=False) + return json.dumps({"success": False, "error": result.get("error", "Navigation failed")}, ensure_ascii=False) -def browser_snapshot( - full: bool = False, - task_id: Optional[str] = None, - user_task: Optional[str] = None -) -> str: +def browser_snapshot(full: bool = False, task_id: str | None = None, user_task: str | None = None) -> str: """ Get a text-based snapshot of the current page's accessibility tree. - + Args: full: If True, return complete snapshot. If False, return compact view. task_id: Task identifier for session isolation user_task: The user's current task (for task-aware extraction) - + Returns: JSON string with page snapshot """ effective_task_id = task_id or "default" - + # Build command args based on full flag args = [] if not full: args.extend(["-c"]) # Compact mode - + result = _run_browser_command(effective_task_id, "snapshot", args) - + if result.get("success"): data = result.get("data", {}) snapshot_text = data.get("snapshot", "") refs = data.get("refs", {}) - + # Check if snapshot needs summarization if len(snapshot_text) > SNAPSHOT_SUMMARIZE_THRESHOLD and user_task: snapshot_text = _extract_relevant_content(snapshot_text, user_task) elif len(snapshot_text) > SNAPSHOT_SUMMARIZE_THRESHOLD: snapshot_text = _truncate_snapshot(snapshot_text) - - response = { - "success": True, - "snapshot": snapshot_text, - "element_count": len(refs) if refs else 0 - } - + + response = {"success": True, "snapshot": snapshot_text, "element_count": len(refs) if refs else 0} + return json.dumps(response, ensure_ascii=False) else: - return json.dumps({ - "success": False, - "error": result.get("error", "Failed to get snapshot") - }, ensure_ascii=False) + return json.dumps( + {"success": False, "error": result.get("error", "Failed to get snapshot")}, ensure_ascii=False + ) -def browser_click(ref: str, task_id: Optional[str] = None) -> str: +def browser_click(ref: str, task_id: str | None = None) -> str: """ Click on an element. - + Args: ref: Element reference (e.g., "@e5") task_id: Task identifier for session isolation - + Returns: JSON string with click result """ effective_task_id = task_id or "default" - + # Ensure ref starts with @ if not ref.startswith("@"): ref = f"@{ref}" - + result = _run_browser_command(effective_task_id, "click", [ref]) - + if result.get("success"): - return json.dumps({ - "success": True, - "clicked": ref - }, ensure_ascii=False) + return json.dumps({"success": True, "clicked": ref}, ensure_ascii=False) else: - return json.dumps({ - "success": False, - "error": result.get("error", f"Failed to click {ref}") - }, ensure_ascii=False) + return json.dumps( + {"success": False, "error": result.get("error", f"Failed to click {ref}")}, ensure_ascii=False + ) -def browser_type(ref: str, text: str, task_id: Optional[str] = None) -> str: +def browser_type(ref: str, text: str, task_id: str | None = None) -> str: """ Type text into an input field. - + Args: ref: Element reference (e.g., "@e3") text: Text to type task_id: Task identifier for session isolation - + Returns: JSON string with type result """ effective_task_id = task_id or "default" - + # Ensure ref starts with @ if not ref.startswith("@"): ref = f"@{ref}" - + # Use fill command (clears then types) result = _run_browser_command(effective_task_id, "fill", [ref, text]) - + if result.get("success"): - return json.dumps({ - "success": True, - "typed": text, - "element": ref - }, ensure_ascii=False) + return json.dumps({"success": True, "typed": text, "element": ref}, ensure_ascii=False) else: - return json.dumps({ - "success": False, - "error": result.get("error", f"Failed to type into {ref}") - }, ensure_ascii=False) + return json.dumps( + {"success": False, "error": result.get("error", f"Failed to type into {ref}")}, ensure_ascii=False + ) -def browser_scroll(direction: str, task_id: Optional[str] = None) -> str: +def browser_scroll(direction: str, task_id: str | None = None) -> str: """ Scroll the page. - + Args: direction: "up" or "down" task_id: Task identifier for session isolation - + Returns: JSON string with scroll result """ effective_task_id = task_id or "default" - + # Validate direction if direction not in ["up", "down"]: - return json.dumps({ - "success": False, - "error": f"Invalid direction '{direction}'. Use 'up' or 'down'." - }, ensure_ascii=False) - + return json.dumps( + {"success": False, "error": f"Invalid direction '{direction}'. Use 'up' or 'down'."}, ensure_ascii=False + ) + result = _run_browser_command(effective_task_id, "scroll", [direction]) - + if result.get("success"): - return json.dumps({ - "success": True, - "scrolled": direction - }, ensure_ascii=False) + return json.dumps({"success": True, "scrolled": direction}, ensure_ascii=False) else: - return json.dumps({ - "success": False, - "error": result.get("error", f"Failed to scroll {direction}") - }, ensure_ascii=False) + return json.dumps( + {"success": False, "error": result.get("error", f"Failed to scroll {direction}")}, ensure_ascii=False + ) -def browser_back(task_id: Optional[str] = None) -> str: +def browser_back(task_id: str | None = None) -> str: """ Navigate back in browser history. - + Args: task_id: Task identifier for session isolation - + Returns: JSON string with navigation result """ effective_task_id = task_id or "default" result = _run_browser_command(effective_task_id, "back", []) - + if result.get("success"): data = result.get("data", {}) - return json.dumps({ - "success": True, - "url": data.get("url", "") - }, ensure_ascii=False) + return json.dumps({"success": True, "url": data.get("url", "")}, ensure_ascii=False) else: - return json.dumps({ - "success": False, - "error": result.get("error", "Failed to go back") - }, ensure_ascii=False) + return json.dumps({"success": False, "error": result.get("error", "Failed to go back")}, ensure_ascii=False) -def browser_press(key: str, task_id: Optional[str] = None) -> str: +def browser_press(key: str, task_id: str | None = None) -> str: """ Press a keyboard key. - + Args: key: Key to press (e.g., "Enter", "Tab") task_id: Task identifier for session isolation - + Returns: JSON string with key press result """ effective_task_id = task_id or "default" result = _run_browser_command(effective_task_id, "press", [key]) - + if result.get("success"): - return json.dumps({ - "success": True, - "pressed": key - }, ensure_ascii=False) + return json.dumps({"success": True, "pressed": key}, ensure_ascii=False) else: - return json.dumps({ - "success": False, - "error": result.get("error", f"Failed to press {key}") - }, ensure_ascii=False) + return json.dumps( + {"success": False, "error": result.get("error", f"Failed to press {key}")}, ensure_ascii=False + ) -def browser_close(task_id: Optional[str] = None) -> str: +def browser_close(task_id: str | None = None) -> str: """ Close the browser session. - + Args: task_id: Task identifier for session isolation - + Returns: JSON string with close result """ effective_task_id = task_id or "default" - + # Stop auto-recording before closing _maybe_stop_recording(effective_task_id) - + result = _run_browser_command(effective_task_id, "close", []) - + # Close the backend session (Browserbase API in cloud mode, nothing extra in local mode) session_key = task_id if task_id and task_id in _active_sessions else "default" if session_key in _active_sessions: @@ -1305,66 +1232,69 @@ def browser_close(task_id: Optional[str] = None) -> str: except Exception as e: logger.warning("Could not close BrowserBase session: %s", e) del _active_sessions[session_key] - + if result.get("success"): - return json.dumps({ - "success": True, - "closed": True - }, ensure_ascii=False) + return json.dumps({"success": True, "closed": True}, ensure_ascii=False) else: # Even if close fails, session was released - return json.dumps({ - "success": True, - "closed": True, - "warning": result.get("error", "Session may not have been active") - }, ensure_ascii=False) + return json.dumps( + {"success": True, "closed": True, "warning": result.get("error", "Session may not have been active")}, + ensure_ascii=False, + ) -def browser_console(clear: bool = False, task_id: Optional[str] = None) -> str: +def browser_console(clear: bool = False, task_id: str | None = None) -> str: """Get browser console messages and JavaScript errors. - + Returns both console output (log/warn/error/info from the page's JS) and uncaught exceptions (crashes, unhandled promise rejections). - + Args: clear: If True, clear the message/error buffers after reading task_id: Task identifier for session isolation - + Returns: JSON string with console messages and JS errors """ effective_task_id = task_id or "default" - + console_args = ["--clear"] if clear else [] error_args = ["--clear"] if clear else [] - + console_result = _run_browser_command(effective_task_id, "console", console_args) errors_result = _run_browser_command(effective_task_id, "errors", error_args) - + messages = [] if console_result.get("success"): for msg in console_result.get("data", {}).get("messages", []): - messages.append({ - "type": msg.get("type", "log"), - "text": msg.get("text", ""), - "source": "console", - }) - + messages.append( + { + "type": msg.get("type", "log"), + "text": msg.get("text", ""), + "source": "console", + } + ) + errors = [] if errors_result.get("success"): for err in errors_result.get("data", {}).get("errors", []): - errors.append({ - "message": err.get("message", ""), - "source": "exception", - }) - - return json.dumps({ - "success": True, - "console_messages": messages, - "js_errors": errors, - "total_messages": len(messages), - "total_errors": len(errors), - }, ensure_ascii=False) + errors.append( + { + "message": err.get("message", ""), + "source": "exception", + } + ) + + return json.dumps( + { + "success": True, + "console_messages": messages, + "js_errors": errors, + "total_messages": len(messages), + "total_errors": len(errors), + }, + ensure_ascii=False, + ) def _maybe_start_recording(task_id: str): @@ -1377,21 +1307,23 @@ def _maybe_start_recording(task_id: str): record_enabled = False if config_path.exists(): import yaml + with open(config_path) as f: cfg = yaml.safe_load(f) or {} record_enabled = cfg.get("browser", {}).get("record_sessions", False) - + if not record_enabled: return - + recordings_dir = hermes_home / "browser_recordings" recordings_dir.mkdir(parents=True, exist_ok=True) _cleanup_old_recordings(max_age_hours=72) - + import time + timestamp = time.strftime("%Y%m%d_%H%M%S") recording_path = recordings_dir / f"session_{timestamp}_{task_id[:16]}.webm" - + result = _run_browser_command(task_id, "record", ["start", str(recording_path)]) if result.get("success"): _recording_sessions.add(task_id) @@ -1417,18 +1349,18 @@ def _maybe_stop_recording(task_id: str): _recording_sessions.discard(task_id) -def browser_get_images(task_id: Optional[str] = None) -> str: +def browser_get_images(task_id: str | None = None) -> str: """ Get all images on the current page. - + Args: task_id: Task identifier for session isolation - + Returns: JSON string with list of images (src and alt) """ effective_task_id = task_id or "default" - + # Use eval to run JavaScript that extracts images js_code = """JSON.stringify( [...document.images].map(img => ({ @@ -1438,121 +1370,112 @@ def browser_get_images(task_id: Optional[str] = None) -> str: height: img.naturalHeight })).filter(img => img.src && !img.src.startsWith('data:')) )""" - + result = _run_browser_command(effective_task_id, "eval", [js_code]) - + if result.get("success"): data = result.get("data", {}) raw_result = data.get("result", "[]") - + try: # Parse the JSON string returned by JavaScript if isinstance(raw_result, str): images = json.loads(raw_result) else: images = raw_result - - return json.dumps({ - "success": True, - "images": images, - "count": len(images) - }, ensure_ascii=False) + + return json.dumps({"success": True, "images": images, "count": len(images)}, ensure_ascii=False) except json.JSONDecodeError: - return json.dumps({ - "success": True, - "images": [], - "count": 0, - "warning": "Could not parse image data" - }, ensure_ascii=False) + return json.dumps( + {"success": True, "images": [], "count": 0, "warning": "Could not parse image data"}, ensure_ascii=False + ) else: - return json.dumps({ - "success": False, - "error": result.get("error", "Failed to get images") - }, ensure_ascii=False) + return json.dumps({"success": False, "error": result.get("error", "Failed to get images")}, ensure_ascii=False) -def browser_vision(question: str, annotate: bool = False, task_id: Optional[str] = None) -> str: +def browser_vision(question: str, annotate: bool = False, task_id: str | None = None) -> str: """ Take a screenshot of the current page and analyze it with vision AI. - + This tool captures what's visually displayed in the browser and sends it to Gemini for analysis. Useful for understanding visual content that the text-based snapshot may not capture (CAPTCHAs, verification challenges, images, complex layouts, etc.). - + The screenshot is saved persistently and its file path is returned alongside the analysis, so it can be shared with users via MEDIA: in the response. - + Args: question: What you want to know about the page visually annotate: If True, overlay numbered [N] labels on interactive elements task_id: Task identifier for session isolation - + Returns: JSON string with vision analysis results and screenshot_path """ import base64 import uuid as uuid_mod from pathlib import Path - + effective_task_id = task_id or "default" - + # Check auxiliary vision client if _aux_vision_client is None or _DEFAULT_VISION_MODEL is None: - return json.dumps({ - "success": False, - "error": "Browser vision unavailable: no auxiliary vision model configured. " - "Set OPENROUTER_API_KEY or configure Nous Portal to enable browser vision." - }, ensure_ascii=False) - + return json.dumps( + { + "success": False, + "error": "Browser vision unavailable: no auxiliary vision model configured. " + "Set OPENROUTER_API_KEY or configure Nous Portal to enable browser vision.", + }, + ensure_ascii=False, + ) + # Save screenshot to persistent location so it can be shared with users hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) screenshots_dir = hermes_home / "browser_screenshots" screenshot_path = screenshots_dir / f"browser_screenshot_{uuid_mod.uuid4().hex}.png" - + try: screenshots_dir.mkdir(parents=True, exist_ok=True) - + # Prune old screenshots (older than 24 hours) to prevent unbounded disk growth _cleanup_old_screenshots(screenshots_dir, max_age_hours=24) - + # Take screenshot using agent-browser screenshot_args = [str(screenshot_path)] if annotate: screenshot_args.insert(0, "--annotate") - result = _run_browser_command( - effective_task_id, - "screenshot", - screenshot_args, - timeout=30 - ) - + result = _run_browser_command(effective_task_id, "screenshot", screenshot_args, timeout=30) + if not result.get("success"): error_detail = result.get("error", "Unknown error") mode = "local" if _is_local_mode() else "cloud" - return json.dumps({ - "success": False, - "error": f"Failed to take screenshot ({mode} mode): {error_detail}" - }, ensure_ascii=False) - + return json.dumps( + {"success": False, "error": f"Failed to take screenshot ({mode} mode): {error_detail}"}, + ensure_ascii=False, + ) + # Check if screenshot file was created if not screenshot_path.exists(): mode = "local" if _is_local_mode() else "cloud" - return json.dumps({ - "success": False, - "error": ( - f"Screenshot file was not created at {screenshot_path} ({mode} mode). " - f"This may indicate a socket path issue (macOS /var/folders/), " - f"a missing Chromium install ('agent-browser install'), " - f"or a stale daemon process." - ), - }, ensure_ascii=False) - + return json.dumps( + { + "success": False, + "error": ( + f"Screenshot file was not created at {screenshot_path} ({mode} mode). " + f"This may indicate a socket path issue (macOS /var/folders/), " + f"a missing Chromium install ('agent-browser install'), " + f"or a stale daemon process." + ), + }, + ensure_ascii=False, + ) + # Read and convert to base64 image_data = screenshot_path.read_bytes() image_base64 = base64.b64encode(image_data).decode("ascii") data_url = f"data:image/png;base64,{image_base64}" - + vision_prompt = ( f"You are analyzing a screenshot of a web browser.\n\n" f"User's question: {question}\n\n" @@ -1564,9 +1487,9 @@ def browser_vision(question: str, annotate: bool = False, task_id: Optional[str] # Use the sync auxiliary vision client directly from agent.auxiliary_client import auxiliary_max_tokens_param + vision_model = _get_vision_model() - logger.debug("browser_vision: analysing screenshot (%d bytes) with model=%s", - len(image_data), vision_model) + logger.debug("browser_vision: analysing screenshot (%d bytes) with model=%s", len(image_data), vision_model) response = _aux_vision_client.chat.completions.create( model=vision_model, messages=[ @@ -1581,7 +1504,7 @@ def browser_vision(question: str, annotate: bool = False, task_id: Optional[str] **auxiliary_max_tokens_param(2000), temperature=0.1, ) - + analysis = response.choices[0].message.content response_data = { "success": True, @@ -1592,7 +1515,7 @@ def browser_vision(question: str, annotate: bool = False, task_id: Optional[str] if annotate and result.get("data", {}).get("annotations"): response_data["annotations"] = result["data"]["annotations"] return json.dumps(response_data, ensure_ascii=False) - + except Exception as e: # Keep the screenshot if it was captured successfully — the failure is # in the LLM vision analysis, not the capture. Deleting a valid @@ -1602,13 +1525,16 @@ def browser_vision(question: str, annotate: bool = False, task_id: Optional[str] error_info = {"success": False, "error": f"Error during vision analysis: {str(e)}"} if screenshot_path.exists(): error_info["screenshot_path"] = str(screenshot_path) - error_info["note"] = "Screenshot was captured but vision analysis failed. You can still share it via MEDIA:." + error_info["note"] = ( + "Screenshot was captured but vision analysis failed. You can still share it via MEDIA:." + ) return json.dumps(error_info, ensure_ascii=False) def _cleanup_old_screenshots(screenshots_dir, max_age_hours=24): """Remove browser screenshots older than max_age_hours to prevent disk bloat.""" import time + try: cutoff = time.time() - (max_age_hours * 3600) for f in screenshots_dir.glob("browser_screenshot_*.png"): @@ -1624,6 +1550,7 @@ def _cleanup_old_screenshots(screenshots_dir, max_age_hours=24): def _cleanup_old_recordings(max_age_hours=72): """Remove browser recordings older than max_age_hours to prevent disk bloat.""" import time + try: hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) recordings_dir = hermes_home / "browser_recordings" @@ -1644,18 +1571,19 @@ def _cleanup_old_recordings(max_age_hours=72): # Cleanup and Management Functions # ============================================================================ + def _close_browserbase_session(session_id: str, api_key: str, project_id: str) -> bool: """ Close a Browserbase session immediately via the API. - + Uses POST /v1/sessions/{id} with status=REQUEST_RELEASE to immediately terminate the session without waiting for keepAlive timeout. - + Args: session_id: The Browserbase session ID api_key: Browserbase API key project_id: Browserbase project ID - + Returns: True if session was successfully closed, False otherwise """ @@ -1663,69 +1591,65 @@ def _close_browserbase_session(session_id: str, api_key: str, project_id: str) - # POST to update session status to REQUEST_RELEASE response = requests.post( f"https://api.browserbase.com/v1/sessions/{session_id}", - headers={ - "X-BB-API-Key": api_key, - "Content-Type": "application/json" - }, - json={ - "projectId": project_id, - "status": "REQUEST_RELEASE" - }, - timeout=10 + headers={"X-BB-API-Key": api_key, "Content-Type": "application/json"}, + json={"projectId": project_id, "status": "REQUEST_RELEASE"}, + timeout=10, ) - + if response.status_code in (200, 201, 204): logger.debug("Successfully closed BrowserBase session %s", session_id) return True else: - logger.warning("Failed to close session %s: HTTP %s - %s", session_id, response.status_code, response.text[:200]) + logger.warning( + "Failed to close session %s: HTTP %s - %s", session_id, response.status_code, response.text[:200] + ) return False - + except Exception as e: logger.error("Exception closing session %s: %s", session_id, e) return False -def cleanup_browser(task_id: Optional[str] = None) -> None: +def cleanup_browser(task_id: str | None = None) -> None: """ Clean up browser session for a task. - + Called automatically when a task completes or when inactivity timeout is reached. Closes both the agent-browser session and the Browserbase session. - + Args: task_id: Task identifier to clean up """ if task_id is None: task_id = "default" - + logger.debug("cleanup_browser called for task_id: %s", task_id) logger.debug("Active sessions: %s", list(_active_sessions.keys())) - + # Check if session exists (under lock), but don't remove yet - # _run_browser_command needs it to build the close command. with _cleanup_lock: session_info = _active_sessions.get(task_id) - + if session_info: bb_session_id = session_info.get("bb_session_id", "unknown") logger.debug("Found session for task %s: bb_session_id=%s", task_id, bb_session_id) - + # Stop auto-recording before closing (saves the file) _maybe_stop_recording(task_id) - + # Try to close via agent-browser first (needs session in _active_sessions) try: _run_browser_command(task_id, "close", [], timeout=10) logger.debug("agent-browser close command completed for task %s", task_id) except Exception as e: logger.warning("agent-browser close failed for task %s: %s", task_id, e) - + # Now remove from tracking under lock with _cleanup_lock: _active_sessions.pop(task_id, None) _session_last_activity.pop(task_id, None) - + # Cloud mode: close the Browserbase session via API if bb_session_id and not _is_local_mode(): try: @@ -1735,7 +1659,7 @@ def cleanup_browser(task_id: Optional[str] = None) -> None: logger.warning("Could not close BrowserBase session %s", bb_session_id) except Exception as e: logger.error("Exception during BrowserBase session close: %s", e) - + # Kill the daemon process and clean up socket directory session_name = session_info.get("session_name", "") if session_name: @@ -1751,7 +1675,7 @@ def cleanup_browser(task_id: Optional[str] = None) -> None: except (ProcessLookupError, ValueError, PermissionError, OSError): pass shutil.rmtree(socket_dir, ignore_errors=True) - + logger.debug("Removed task %s from active sessions", task_id) else: logger.debug("No active session found for task_id: %s", task_id) @@ -1760,7 +1684,7 @@ def cleanup_browser(task_id: Optional[str] = None) -> None: def cleanup_all_browsers() -> None: """ Clean up all active browser sessions. - + Useful for cleanup on shutdown. """ with _cleanup_lock: @@ -1769,10 +1693,10 @@ def cleanup_all_browsers() -> None: cleanup_browser(task_id) -def get_active_browser_sessions() -> Dict[str, Dict[str, str]]: +def get_active_browser_sessions() -> dict[str, dict[str, str]]: """ Get information about active browser sessions. - + Returns: Dict mapping task_id to session info (session_name, bb_session_id, cdp_url) """ @@ -1784,6 +1708,7 @@ def get_active_browser_sessions() -> Dict[str, Dict[str, str]]: # Requirements Check # ============================================================================ + def check_browser_requirements() -> bool: """ Check if browser tool requirements are met. @@ -1793,7 +1718,7 @@ def check_browser_requirements() -> bool: In **cloud mode** (BROWSERBASE_API_KEY set): the CLI *and* both ``BROWSERBASE_API_KEY`` / ``BROWSERBASE_PROJECT_ID`` must be present. - + Returns: True if all requirements are met, False otherwise """ @@ -1826,7 +1751,7 @@ if __name__ == "__main__": mode = "local" if _is_local_mode() else "cloud (Browserbase)" print(f" Mode: {mode}") - + # Check requirements if check_browser_requirements(): print("✅ All requirements met") @@ -1843,11 +1768,11 @@ if __name__ == "__main__": if not os.environ.get("BROWSERBASE_PROJECT_ID"): print(" - BROWSERBASE_PROJECT_ID not set (required for cloud mode)") print(" Tip: unset BROWSERBASE_API_KEY to use free local mode instead") - + print("\n📋 Available Browser Tools:") for schema in BROWSER_TOOL_SCHEMAS: print(f" 🔹 {schema['name']}: {schema['description'][:60]}...") - + print("\n💡 Usage:") print(" from tools.browser_tool import browser_navigate, browser_snapshot") print(" result = browser_navigate('https://example.com', task_id='my_task')") @@ -1873,7 +1798,8 @@ registry.register( toolset="browser", schema=_BROWSER_SCHEMA_MAP["browser_snapshot"], handler=lambda args, **kw: browser_snapshot( - full=args.get("full", False), task_id=kw.get("task_id"), user_task=kw.get("user_task")), + full=args.get("full", False), task_id=kw.get("task_id"), user_task=kw.get("user_task") + ), check_fn=check_browser_requirements, ) registry.register( @@ -1929,7 +1855,9 @@ registry.register( name="browser_vision", toolset="browser", schema=_BROWSER_SCHEMA_MAP["browser_vision"], - handler=lambda args, **kw: browser_vision(question=args.get("question", ""), annotate=args.get("annotate", False), task_id=kw.get("task_id")), + handler=lambda args, **kw: browser_vision( + question=args.get("question", ""), annotate=args.get("annotate", False), task_id=kw.get("task_id") + ), check_fn=check_browser_requirements, ) registry.register( diff --git a/tools/clarify_tool.py b/tools/clarify_tool.py index e0552357b6..26452f4011 100644 --- a/tools/clarify_tool.py +++ b/tools/clarify_tool.py @@ -12,8 +12,7 @@ a thin dispatcher that delegates to a platform-provided callback. """ import json -from typing import Dict, Any, List, Optional, Callable - +from collections.abc import Callable # Maximum number of predefined choices the agent can offer. # A 5th "Other (type your answer)" option is always appended by the UI. @@ -22,8 +21,8 @@ MAX_CHOICES = 4 def clarify_tool( question: str, - choices: Optional[List[str]] = None, - callback: Optional[Callable] = None, + choices: list[str] | None = None, + callback: Callable | None = None, ) -> str: """ Ask the user a question, optionally with multiple-choice options. @@ -68,11 +67,14 @@ def clarify_tool( ensure_ascii=False, ) - return json.dumps({ - "question": question, - "choices_offered": choices, - "user_response": str(user_response).strip(), - }, ensure_ascii=False) + return json.dumps( + { + "question": question, + "choices_offered": choices, + "user_response": str(user_response).strip(), + }, + ensure_ascii=False, + ) def check_clarify_requirements() -> bool: @@ -133,8 +135,7 @@ registry.register( toolset="clarify", schema=CLARIFY_SCHEMA, handler=lambda args, **kw: clarify_tool( - question=args.get("question", ""), - choices=args.get("choices"), - callback=kw.get("callback")), + question=args.get("question", ""), choices=args.get("choices"), callback=kw.get("callback") + ), check_fn=check_clarify_requirements, ) diff --git a/tools/code_execution_tool.py b/tools/code_execution_tool.py index 7ea8fa8e40..efb7592408 100644 --- a/tools/code_execution_tool.py +++ b/tools/code_execution_tool.py @@ -31,7 +31,7 @@ import time import uuid _IS_WINDOWS = platform.system() == "Windows" -from typing import Any, Dict, List, Optional +from typing import Any # Availability gate: UDS requires a POSIX OS logger = logging.getLogger(__name__) @@ -40,21 +40,23 @@ SANDBOX_AVAILABLE = sys.platform != "win32" # The 7 tools allowed inside the sandbox. The intersection of this list # and the session's enabled tools determines which stubs are generated. -SANDBOX_ALLOWED_TOOLS = frozenset([ - "web_search", - "web_extract", - "read_file", - "write_file", - "search_files", - "patch", - "terminal", -]) +SANDBOX_ALLOWED_TOOLS = frozenset( + [ + "web_search", + "web_extract", + "read_file", + "write_file", + "search_files", + "patch", + "terminal", + ] +) # Resource limit defaults (overridable via config.yaml → code_execution.*) -DEFAULT_TIMEOUT = 300 # 5 minutes +DEFAULT_TIMEOUT = 300 # 5 minutes DEFAULT_MAX_TOOL_CALLS = 50 -MAX_STDOUT_BYTES = 50_000 # 50 KB -MAX_STDERR_BYTES = 10_000 # 10 KB +MAX_STDOUT_BYTES = 50_000 # 50 KB +MAX_STDERR_BYTES = 10_000 # 10 KB def check_sandbox_requirements() -> bool: @@ -114,7 +116,7 @@ _TOOL_STUBS = { } -def generate_hermes_tools_module(enabled_tools: List[str]) -> str: +def generate_hermes_tools_module(enabled_tools: list[str]) -> str: """ Build the source code for the hermes_tools.py stub module. @@ -128,11 +130,7 @@ def generate_hermes_tools_module(enabled_tools: List[str]) -> str: if tool_name not in _TOOL_STUBS: continue func_name, sig, doc, args_expr = _TOOL_STUBS[tool_name] - stub_functions.append( - f"def {func_name}({sig}):\n" - f" {doc}\n" - f" return _call({func_name!r}, {args_expr})\n" - ) + stub_functions.append(f"def {func_name}({sig}):\n {doc}\n return _call({func_name!r}, {args_expr})\n") export_names.append(func_name) header = '''\ @@ -223,7 +221,7 @@ def _rpc_server_loop( server_sock: socket.socket, task_id: str, tool_call_log: list, - tool_call_counter: list, # mutable [int] so the thread can increment + tool_call_counter: list, # mutable [int] so the thread can increment max_tool_calls: int, allowed_tools: frozenset, ): @@ -243,7 +241,7 @@ def _rpc_server_loop( while True: try: chunk = conn.recv(65536) - except socket.timeout: + except TimeoutError: break if not chunk: break @@ -270,23 +268,22 @@ def _rpc_server_loop( # Enforce the allow-list if tool_name not in allowed_tools: available = ", ".join(sorted(allowed_tools)) - resp = json.dumps({ - "error": ( - f"Tool '{tool_name}' is not available in execute_code. " - f"Available: {available}" - ) - }) + resp = json.dumps( + {"error": (f"Tool '{tool_name}' is not available in execute_code. Available: {available}")} + ) conn.sendall((resp + "\n").encode()) continue # Enforce tool call limit if tool_call_counter[0] >= max_tool_calls: - resp = json.dumps({ - "error": ( - f"Tool call limit reached ({max_tool_calls}). " - "No more tool calls allowed in this execution." - ) - }) + resp = json.dumps( + { + "error": ( + f"Tool call limit reached ({max_tool_calls}). " + "No more tool calls allowed in this execution." + ) + } + ) conn.sendall((resp + "\n").encode()) continue @@ -303,9 +300,7 @@ def _rpc_server_loop( sys.stdout = open(os.devnull, "w") sys.stderr = open(os.devnull, "w") try: - result = handle_function_call( - tool_name, tool_args, task_id=task_id - ) + result = handle_function_call(tool_name, tool_args, task_id=task_id) finally: sys.stdout.close() sys.stderr.close() @@ -318,15 +313,17 @@ def _rpc_server_loop( # Log for observability args_preview = str(tool_args)[:80] - tool_call_log.append({ - "tool": tool_name, - "args_preview": args_preview, - "duration": round(call_duration, 2), - }) + tool_call_log.append( + { + "tool": tool_name, + "args_preview": args_preview, + "duration": round(call_duration, 2), + } + ) conn.sendall((result + "\n").encode()) - except socket.timeout: + except TimeoutError: pass except OSError: pass @@ -342,10 +339,11 @@ def _rpc_server_loop( # Main entry point # --------------------------------------------------------------------------- + def execute_code( code: str, - task_id: Optional[str] = None, - enabled_tools: Optional[List[str]] = None, + task_id: str | None = None, + enabled_tools: list[str] | None = None, ) -> str: """ Run a Python script in a sandboxed child process with RPC access @@ -361,9 +359,7 @@ def execute_code( JSON string with execution results. """ if not SANDBOX_AVAILABLE: - return json.dumps({ - "error": "execute_code is not available on Windows. Use normal tool calls instead." - }) + return json.dumps({"error": "execute_code is not available on Windows. Use normal tool calls instead."}) if not code or not code.strip(): return json.dumps({"error": "No code provided."}) @@ -397,9 +393,7 @@ def execute_code( try: # Write the auto-generated hermes_tools module - tools_src = generate_hermes_tools_module( - list(sandbox_tools) if enabled_tools else list(SANDBOX_ALLOWED_TOOLS) - ) + tools_src = generate_hermes_tools_module(list(sandbox_tools) if enabled_tools else list(SANDBOX_ALLOWED_TOOLS)) with open(os.path.join(tmpdir, "hermes_tools.py"), "w") as f: f.write(tools_src) @@ -415,8 +409,12 @@ def execute_code( rpc_thread = threading.Thread( target=_rpc_server_loop, args=( - server_sock, task_id, tool_call_log, - tool_call_counter, max_tool_calls, sandbox_tools, + server_sock, + task_id, + tool_call_log, + tool_call_counter, + max_tool_calls, + sandbox_tools, ), daemon=True, ) @@ -426,11 +424,24 @@ def execute_code( # Build a minimal environment for the child. We intentionally exclude # API keys and tokens to prevent credential exfiltration from LLM- # generated scripts. The child accesses tools via RPC, not direct API. - _SAFE_ENV_PREFIXES = ("PATH", "HOME", "USER", "LANG", "LC_", "TERM", - "TMPDIR", "TMP", "TEMP", "SHELL", "LOGNAME", - "XDG_", "PYTHONPATH", "VIRTUAL_ENV", "CONDA") - _SECRET_SUBSTRINGS = ("KEY", "TOKEN", "SECRET", "PASSWORD", "CREDENTIAL", - "PASSWD", "AUTH") + _SAFE_ENV_PREFIXES = ( + "PATH", + "HOME", + "USER", + "LANG", + "LC_", + "TERM", + "TMPDIR", + "TMP", + "TEMP", + "SHELL", + "LOGNAME", + "XDG_", + "PYTHONPATH", + "VIRTUAL_ENV", + "CONDA", + ) + _SECRET_SUBSTRINGS = ("KEY", "TOKEN", "SECRET", "PASSWORD", "CREDENTIAL", "PASSWD", "AUTH") child_env = {} for k, v in os.environ.items(): if any(s in k.upper() for s in _SECRET_SUBSTRINGS): @@ -515,7 +526,7 @@ def execute_code( rpc_thread.join(timeout=3) # Build response - result: Dict[str, Any] = { + result: dict[str, Any] = { "status": status, "output": stdout_text, "tool_calls_made": tool_call_counter[0], @@ -538,17 +549,21 @@ def execute_code( except Exception as exc: duration = round(time.monotonic() - exec_start, 2) logging.exception("execute_code failed") - return json.dumps({ - "status": "error", - "error": str(exc), - "tool_calls_made": tool_call_counter[0], - "duration_seconds": duration, - }, ensure_ascii=False) + return json.dumps( + { + "status": "error", + "error": str(exc), + "tool_calls_made": tool_call_counter[0], + "duration_seconds": duration, + }, + ensure_ascii=False, + ) finally: # Cleanup temp dir and socket try: import shutil + shutil.rmtree(tmpdir, ignore_errors=True) except Exception as e: logger.debug("Could not clean temp dir: %s", e) @@ -592,6 +607,7 @@ def _load_config() -> dict: """Load code_execution config from CLI_CONFIG if available.""" try: from cli import CLI_CONFIG + return CLI_CONFIG.get("code_execution", {}) except Exception: return {} @@ -604,27 +620,37 @@ def _load_config() -> dict: # Per-tool documentation lines for the execute_code description. # Ordered to match the canonical display order. _TOOL_DOC_LINES = [ - ("web_search", - " web_search(query: str, limit: int = 5) -> dict\n" - " Returns {\"data\": {\"web\": [{\"url\", \"title\", \"description\"}, ...]}}"), - ("web_extract", - " web_extract(urls: list[str]) -> dict\n" - " Returns {\"results\": [{\"url\", \"title\", \"content\", \"error\"}, ...]} where content is markdown"), - ("read_file", - " read_file(path: str, offset: int = 1, limit: int = 500) -> dict\n" - " Lines are 1-indexed. Returns {\"content\": \"...\", \"total_lines\": N}"), - ("write_file", - " write_file(path: str, content: str) -> dict\n" - " Always overwrites the entire file."), - ("search_files", - " search_files(pattern: str, target=\"content\", path=\".\", file_glob=None, limit=50) -> dict\n" - " target: \"content\" (search inside files) or \"files\" (find files by name). Returns {\"matches\": [...]}"), - ("patch", - " patch(path: str, old_string: str, new_string: str, replace_all: bool = False) -> dict\n" - " Replaces old_string with new_string in the file."), - ("terminal", - " terminal(command: str, timeout=None, workdir=None) -> dict\n" - " Foreground only (no background/pty). Returns {\"output\": \"...\", \"exit_code\": N}"), + ( + "web_search", + " web_search(query: str, limit: int = 5) -> dict\n" + ' Returns {"data": {"web": [{"url", "title", "description"}, ...]}}', + ), + ( + "web_extract", + " web_extract(urls: list[str]) -> dict\n" + ' Returns {"results": [{"url", "title", "content", "error"}, ...]} where content is markdown', + ), + ( + "read_file", + " read_file(path: str, offset: int = 1, limit: int = 500) -> dict\n" + ' Lines are 1-indexed. Returns {"content": "...", "total_lines": N}', + ), + ("write_file", " write_file(path: str, content: str) -> dict\n Always overwrites the entire file."), + ( + "search_files", + ' search_files(pattern: str, target="content", path=".", file_glob=None, limit=50) -> dict\n' + ' target: "content" (search inside files) or "files" (find files by name). Returns {"matches": [...]}', + ), + ( + "patch", + " patch(path: str, old_string: str, new_string: str, replace_all: bool = False) -> dict\n" + " Replaces old_string with new_string in the file.", + ), + ( + "terminal", + " terminal(command: str, timeout=None, workdir=None) -> dict\n" + ' Foreground only (no background/pty). Returns {"output": "...", "exit_code": N}', + ), ] @@ -639,9 +665,7 @@ def build_execute_code_schema(enabled_sandbox_tools: set = None) -> dict: enabled_sandbox_tools = SANDBOX_ALLOWED_TOOLS # Build tool documentation lines for only the enabled tools - tool_lines = "\n".join( - doc for name, doc in _TOOL_DOC_LINES if name in enabled_sandbox_tools - ) + tool_lines = "\n".join(doc for name, doc in _TOOL_DOC_LINES if name in enabled_sandbox_tools) # Build example import list from enabled tools import_examples = [n for n in ("web_search", "terminal") if n in enabled_sandbox_tools] @@ -702,8 +726,7 @@ registry.register( toolset="code_execution", schema=EXECUTE_CODE_SCHEMA, handler=lambda args, **kw: execute_code( - code=args.get("code", ""), - task_id=kw.get("task_id"), - enabled_tools=kw.get("enabled_tools")), + code=args.get("code", ""), task_id=kw.get("task_id"), enabled_tools=kw.get("enabled_tools") + ), check_fn=check_sandbox_requirements, ) diff --git a/tools/cronjob_tools.py b/tools/cronjob_tools.py index bdfa58d630..6b1e52f472 100644 --- a/tools/cronjob_tools.py +++ b/tools/cronjob_tools.py @@ -11,37 +11,44 @@ The prompt must contain ALL necessary information. import json import os import re -from typing import Optional # Import from cron module (will be available when properly installed) import sys from pathlib import Path + sys.path.insert(0, str(Path(__file__).parent.parent)) from cron.jobs import create_job, get_job, list_jobs, remove_job - # --------------------------------------------------------------------------- # Cron prompt scanning — critical-severity patterns only, since cron prompts # run in fresh sessions with full tool access. # --------------------------------------------------------------------------- _CRON_THREAT_PATTERNS = [ - (r'ignore\s+(?:\w+\s+)*(?:previous|all|above|prior)\s+(?:\w+\s+)*instructions', "prompt_injection"), - (r'do\s+not\s+tell\s+the\s+user', "deception_hide"), - (r'system\s+prompt\s+override', "sys_prompt_override"), - (r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"), - (r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"), - (r'wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_wget"), - (r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)', "read_secrets"), - (r'authorized_keys', "ssh_backdoor"), - (r'/etc/sudoers|visudo', "sudoers_mod"), - (r'rm\s+-rf\s+/', "destructive_root_rm"), + (r"ignore\s+(?:\w+\s+)*(?:previous|all|above|prior)\s+(?:\w+\s+)*instructions", "prompt_injection"), + (r"do\s+not\s+tell\s+the\s+user", "deception_hide"), + (r"system\s+prompt\s+override", "sys_prompt_override"), + (r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"), + (r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"), + (r"wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_wget"), + (r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)", "read_secrets"), + (r"authorized_keys", "ssh_backdoor"), + (r"/etc/sudoers|visudo", "sudoers_mod"), + (r"rm\s+-rf\s+/", "destructive_root_rm"), ] _CRON_INVISIBLE_CHARS = { - '\u200b', '\u200c', '\u200d', '\u2060', '\ufeff', - '\u202a', '\u202b', '\u202c', '\u202d', '\u202e', + "\u200b", + "\u200c", + "\u200d", + "\u2060", + "\ufeff", + "\u202a", + "\u202b", + "\u202c", + "\u202d", + "\u202e", } @@ -60,17 +67,18 @@ def _scan_cron_prompt(prompt: str) -> str: # Tool: schedule_cronjob # ============================================================================= + def schedule_cronjob( prompt: str, schedule: str, - name: Optional[str] = None, - repeat: Optional[int] = None, - deliver: Optional[str] = None, - task_id: str = None + name: str | None = None, + repeat: int | None = None, + deliver: str | None = None, + task_id: str = None, ) -> str: """ Schedule an automated task to run the agent on a schedule. - + IMPORTANT: When the cronjob runs, it starts a COMPLETELY FRESH session. The agent will have NO memory of this conversation or any prior context. Therefore, the prompt MUST contain ALL necessary information: @@ -78,12 +86,12 @@ def schedule_cronjob( - Specific file paths, URLs, or identifiers - Clear success criteria - Any relevant background information - + BAD prompt: "Check on that server issue" - GOOD prompt: "SSH into server 192.168.1.100 as user 'deploy', check if nginx - is running with 'systemctl status nginx', and verify the site + GOOD prompt: "SSH into server 192.168.1.100 as user 'deploy', check if nginx + is running with 'systemctl status nginx', and verify the site https://example.com returns HTTP 200. Report any issues found." - + Args: prompt: Complete, self-contained instructions for the future agent. Must include ALL context needed - the agent won't remember anything. @@ -105,7 +113,7 @@ def schedule_cronjob( - "signal": Send to Signal home channel - "telegram:123456": Send to specific chat ID - "signal:+15551234567": Send to specific Signal number - + Returns: JSON with job_id, next_run time, and confirmation """ @@ -124,17 +132,10 @@ def schedule_cronjob( "chat_id": origin_chat_id, "chat_name": os.getenv("HERMES_SESSION_CHAT_NAME"), } - + try: - job = create_job( - prompt=prompt, - schedule=schedule, - name=name, - repeat=repeat, - deliver=deliver, - origin=origin - ) - + job = create_job(prompt=prompt, schedule=schedule, name=name, repeat=repeat, deliver=deliver, origin=origin) + # Format repeat info for display times = job["repeat"].get("times") if times is None: @@ -143,23 +144,23 @@ def schedule_cronjob( repeat_display = "once" else: repeat_display = f"{times} times" - - return json.dumps({ - "success": True, - "job_id": job["id"], - "name": job["name"], - "schedule": job["schedule_display"], - "repeat": repeat_display, - "deliver": job.get("deliver", "local"), - "next_run_at": job["next_run_at"], - "message": f"Cronjob '{job['name']}' created. It will run {repeat_display}, deliver to {job.get('deliver', 'local')}, next at {job['next_run_at']}." - }, indent=2) - + + return json.dumps( + { + "success": True, + "job_id": job["id"], + "name": job["name"], + "schedule": job["schedule_display"], + "repeat": repeat_display, + "deliver": job.get("deliver", "local"), + "next_run_at": job["next_run_at"], + "message": f"Cronjob '{job['name']}' created. It will run {repeat_display}, deliver to {job.get('deliver', 'local')}, next at {job['next_run_at']}.", + }, + indent=2, + ) + except Exception as e: - return json.dumps({ - "success": False, - "error": str(e) - }, indent=2) + return json.dumps({"success": False, "error": str(e)}, indent=2) SCHEDULE_CRONJOB_SCHEMA = { @@ -177,7 +178,7 @@ The future agent will NOT remember anything from the current conversation. SCHEDULE FORMATS: - One-shot: "30m", "2h", "1d" (runs once after delay) -- Interval: "every 30m", "every 2h" (recurring) +- Interval: "every 30m", "every 2h" (recurring) - Cron: "0 9 * * *" (cron expression for precise scheduling) - Timestamp: "2026-02-03T14:00:00" (specific date/time) @@ -202,27 +203,24 @@ Use for: reminders, periodic checks, scheduled reports, automated maintenance."" "properties": { "prompt": { "type": "string", - "description": "Complete, self-contained instructions. Must include ALL context - the future agent will have NO memory of this conversation." + "description": "Complete, self-contained instructions. Must include ALL context - the future agent will have NO memory of this conversation.", }, "schedule": { "type": "string", - "description": "When to run: '30m' (once in 30min), 'every 30m' (recurring), '0 9 * * *' (cron), or ISO timestamp" - }, - "name": { - "type": "string", - "description": "Optional human-friendly name for the job" + "description": "When to run: '30m' (once in 30min), 'every 30m' (recurring), '0 9 * * *' (cron), or ISO timestamp", }, + "name": {"type": "string", "description": "Optional human-friendly name for the job"}, "repeat": { "type": "integer", - "description": "How many times to run. Omit for default (once for one-shot, forever for recurring). Set to N for exactly N runs." + "description": "How many times to run. Omit for default (once for one-shot, forever for recurring). Set to N for exactly N runs.", }, "deliver": { "type": "string", - "description": "Where to send output: 'origin' (back to this chat), 'local' (files only), 'telegram', 'discord', 'signal', or 'platform:chat_id'" - } + "description": "Where to send output: 'origin' (back to this chat), 'local' (files only), 'telegram', 'discord', 'signal', or 'platform:chat_id'", + }, }, - "required": ["prompt", "schedule"] - } + "required": ["prompt", "schedule"], + }, } @@ -230,10 +228,11 @@ Use for: reminders, periodic checks, scheduled reports, automated maintenance."" # Tool: list_cronjobs # ============================================================================= + def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str: """ List all scheduled cronjobs. - + Returns information about each job including: - Job ID (needed for removal) - Name @@ -241,16 +240,16 @@ def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str: - Repeat status (completed/total or 'forever') - Next scheduled run time - Last run time and status (if any) - + Args: include_disabled: Whether to include disabled/completed jobs - + Returns: JSON array of all scheduled jobs """ try: jobs = list_jobs(include_disabled=include_disabled) - + formatted_jobs = [] for job in jobs: # Format repeat status @@ -260,31 +259,26 @@ def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str: repeat_status = "forever" else: repeat_status = f"{completed}/{times}" - - formatted_jobs.append({ - "job_id": job["id"], - "name": job["name"], - "prompt_preview": job["prompt"][:100] + "..." if len(job["prompt"]) > 100 else job["prompt"], - "schedule": job["schedule_display"], - "repeat": repeat_status, - "deliver": job.get("deliver", "local"), - "next_run_at": job.get("next_run_at"), - "last_run_at": job.get("last_run_at"), - "last_status": job.get("last_status"), - "enabled": job.get("enabled", True) - }) - - return json.dumps({ - "success": True, - "count": len(formatted_jobs), - "jobs": formatted_jobs - }, indent=2) - + + formatted_jobs.append( + { + "job_id": job["id"], + "name": job["name"], + "prompt_preview": job["prompt"][:100] + "..." if len(job["prompt"]) > 100 else job["prompt"], + "schedule": job["schedule_display"], + "repeat": repeat_status, + "deliver": job.get("deliver", "local"), + "next_run_at": job.get("next_run_at"), + "last_run_at": job.get("last_run_at"), + "last_status": job.get("last_status"), + "enabled": job.get("enabled", True), + } + ) + + return json.dumps({"success": True, "count": len(formatted_jobs), "jobs": formatted_jobs}, indent=2) + except Exception as e: - return json.dumps({ - "success": False, - "error": str(e) - }, indent=2) + return json.dumps({"success": False, "error": str(e)}, indent=2) LIST_CRONJOBS_SCHEMA = { @@ -302,11 +296,11 @@ Returns job_id, name, schedule, repeat status, next/last run times.""", "properties": { "include_disabled": { "type": "boolean", - "description": "Include disabled/completed jobs in the list (default: false)" + "description": "Include disabled/completed jobs in the list (default: false)", } }, - "required": [] - } + "required": [], + }, } @@ -314,48 +308,45 @@ Returns job_id, name, schedule, repeat status, next/last run times.""", # Tool: remove_cronjob # ============================================================================= + def remove_cronjob(job_id: str, task_id: str = None) -> str: """ Remove a scheduled cronjob by its ID. - + Use list_cronjobs first to find the job_id of the job you want to remove. - + Args: job_id: The ID of the job to remove (from list_cronjobs output) - + Returns: JSON confirmation of removal """ try: job = get_job(job_id) if not job: - return json.dumps({ - "success": False, - "error": f"Job with ID '{job_id}' not found. Use list_cronjobs to see available jobs." - }, indent=2) - + return json.dumps( + { + "success": False, + "error": f"Job with ID '{job_id}' not found. Use list_cronjobs to see available jobs.", + }, + indent=2, + ) + removed = remove_job(job_id) if removed: - return json.dumps({ - "success": True, - "message": f"Cronjob '{job['name']}' (ID: {job_id}) has been removed.", - "removed_job": { - "id": job_id, - "name": job["name"], - "schedule": job["schedule_display"] - } - }, indent=2) + return json.dumps( + { + "success": True, + "message": f"Cronjob '{job['name']}' (ID: {job_id}) has been removed.", + "removed_job": {"id": job_id, "name": job["name"], "schedule": job["schedule_display"]}, + }, + indent=2, + ) else: - return json.dumps({ - "success": False, - "error": f"Failed to remove job '{job_id}'" - }, indent=2) - + return json.dumps({"success": False, "error": f"Failed to remove job '{job_id}'"}, indent=2) + except Exception as e: - return json.dumps({ - "success": False, - "error": str(e) - }, indent=2) + return json.dumps({"success": False, "error": str(e)}, indent=2) REMOVE_CRONJOB_SCHEMA = { @@ -368,13 +359,10 @@ use this to cancel a job before it completes.""", "parameters": { "type": "object", "properties": { - "job_id": { - "type": "string", - "description": "The ID of the cronjob to remove (from list_cronjobs output)" - } + "job_id": {"type": "string", "description": "The ID of the cronjob to remove (from list_cronjobs output)"} }, - "required": ["job_id"] - } + "required": ["job_id"], + }, } @@ -382,44 +370,34 @@ use this to cancel a job before it completes.""", # Requirements check # ============================================================================= + def check_cronjob_requirements() -> bool: """ Check if cronjob tools can be used. - + Available in interactive CLI mode and gateway/messaging platforms. Cronjobs are server-side scheduled tasks so they work from any interface. """ - return bool( - os.getenv("HERMES_INTERACTIVE") - or os.getenv("HERMES_GATEWAY_SESSION") - or os.getenv("HERMES_EXEC_ASK") - ) + return bool(os.getenv("HERMES_INTERACTIVE") or os.getenv("HERMES_GATEWAY_SESSION") or os.getenv("HERMES_EXEC_ASK")) # ============================================================================= # Exports # ============================================================================= + def get_cronjob_tool_definitions(): """Return tool definitions for cronjob management.""" - return [ - SCHEDULE_CRONJOB_SCHEMA, - LIST_CRONJOBS_SCHEMA, - REMOVE_CRONJOB_SCHEMA - ] + return [SCHEDULE_CRONJOB_SCHEMA, LIST_CRONJOBS_SCHEMA, REMOVE_CRONJOB_SCHEMA] # For direct testing if __name__ == "__main__": # Test the tools print("Testing schedule_cronjob:") - result = schedule_cronjob( - prompt="Test prompt for cron job", - schedule="5m", - name="Test Job" - ) + result = schedule_cronjob(prompt="Test prompt for cron job", schedule="5m", name="Test Job") print(result) - + print("\nTesting list_cronjobs:") result = list_cronjobs() print(result) @@ -438,7 +416,8 @@ registry.register( name=args.get("name"), repeat=args.get("repeat"), deliver=args.get("deliver"), - task_id=kw.get("task_id")), + task_id=kw.get("task_id"), + ), check_fn=check_cronjob_requirements, ) registry.register( @@ -446,16 +425,14 @@ registry.register( toolset="cronjob", schema=LIST_CRONJOBS_SCHEMA, handler=lambda args, **kw: list_cronjobs( - include_disabled=args.get("include_disabled", False), - task_id=kw.get("task_id")), + include_disabled=args.get("include_disabled", False), task_id=kw.get("task_id") + ), check_fn=check_cronjob_requirements, ) registry.register( name="remove_cronjob", toolset="cronjob", schema=REMOVE_CRONJOB_SCHEMA, - handler=lambda args, **kw: remove_cronjob( - job_id=args.get("job_id", ""), - task_id=kw.get("task_id")), + handler=lambda args, **kw: remove_cronjob(job_id=args.get("job_id", ""), task_id=kw.get("task_id")), check_fn=check_cronjob_requirements, ) diff --git a/tools/debug_helpers.py b/tools/debug_helpers.py index f1934fd5be..d23251cf4a 100644 --- a/tools/debug_helpers.py +++ b/tools/debug_helpers.py @@ -27,7 +27,7 @@ import logging import os import uuid from pathlib import Path -from typing import Any, Dict +from typing import Any logger = logging.getLogger(__name__) @@ -44,27 +44,28 @@ class DebugSession: self.enabled = os.getenv(env_var, "false").lower() == "true" self.session_id = str(uuid.uuid4()) if self.enabled else "" self.log_dir = Path("./logs") - self._calls: list[Dict[str, Any]] = [] + self._calls: list[dict[str, Any]] = [] self._start_time = datetime.datetime.now().isoformat() if self.enabled else "" if self.enabled: self.log_dir.mkdir(exist_ok=True) - logger.debug("%s debug mode enabled - Session ID: %s", - tool_name, self.session_id) + logger.debug("%s debug mode enabled - Session ID: %s", tool_name, self.session_id) @property def active(self) -> bool: return self.enabled - def log_call(self, call_name: str, call_data: Dict[str, Any]) -> None: + def log_call(self, call_name: str, call_data: dict[str, Any]) -> None: """Append a tool-call entry to the in-memory log.""" if not self.enabled: return - self._calls.append({ - "timestamp": datetime.datetime.now().isoformat(), - "tool_name": call_name, - **call_data, - }) + self._calls.append( + { + "timestamp": datetime.datetime.now().isoformat(), + "tool_name": call_name, + **call_data, + } + ) def save(self) -> None: """Flush the in-memory log to a JSON file in the logs directory.""" @@ -87,7 +88,7 @@ class DebugSession: except Exception as e: logger.error("Error saving %s debug log: %s", self.tool_name, e) - def get_session_info(self) -> Dict[str, Any]: + def get_session_info(self) -> dict[str, Any]: """Return a summary dict suitable for returning from get_debug_session_info().""" if not self.enabled: return { diff --git a/tools/delegate_tool.py b/tools/delegate_tool.py index c8de97225d..8be3f1e0fe 100644 --- a/tools/delegate_tool.py +++ b/tools/delegate_tool.py @@ -20,21 +20,22 @@ import contextlib import io import json import logging -import os import sys import time +from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Dict, List, Optional - +from typing import Any # Tools that children must never have access to -DELEGATE_BLOCKED_TOOLS = frozenset([ - "delegate_task", # no recursive delegation - "clarify", # no user interaction - "memory", # no writes to shared MEMORY.md - "send_message", # no cross-platform side effects - "execute_code", # children should reason step-by-step, not write scripts -]) +DELEGATE_BLOCKED_TOOLS = frozenset( + [ + "delegate_task", # no recursive delegation + "clarify", # no user interaction + "memory", # no writes to shared MEMORY.md + "send_message", # no cross-platform side effects + "execute_code", # children should reason step-by-step, not write scripts + ] +) MAX_CONCURRENT_CHILDREN = 3 MAX_DEPTH = 2 # parent (0) -> child (1) -> grandchild rejected (2) @@ -47,7 +48,7 @@ def check_delegate_requirements() -> bool: return True -def _build_child_system_prompt(goal: str, context: Optional[str] = None) -> str: +def _build_child_system_prompt(goal: str, context: str | None = None) -> str: """Build a focused system prompt for a child agent.""" parts = [ "You are a focused subagent working on a specific delegated task.", @@ -69,15 +70,18 @@ def _build_child_system_prompt(goal: str, context: Optional[str] = None) -> str: return "\n".join(parts) -def _strip_blocked_tools(toolsets: List[str]) -> List[str]: +def _strip_blocked_tools(toolsets: list[str]) -> list[str]: """Remove toolsets that contain only blocked tools.""" blocked_toolset_names = { - "delegation", "clarify", "memory", "code_execution", + "delegation", + "clarify", + "memory", + "code_execution", } return [t for t in toolsets if t not in blocked_toolset_names] -def _build_child_progress_callback(task_index: int, parent_agent, task_count: int = 1) -> Optional[callable]: +def _build_child_progress_callback(task_index: int, parent_agent, task_count: int = 1) -> Callable | None: """Build a callback that relays child agent tool calls to the parent display. Two display paths: @@ -87,8 +91,8 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in Returns None if no display mechanism is available, in which case the child agent runs with no progress callback (identical to current behavior). """ - spinner = getattr(parent_agent, '_delegate_spinner', None) - parent_cb = getattr(parent_agent, 'tool_progress_callback', None) + spinner = getattr(parent_agent, "_delegate_spinner", None) + parent_cb = getattr(parent_agent, "tool_progress_callback", None) if not spinner and not parent_cb: return None # No display → no callback → zero behavior change @@ -98,7 +102,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in # Gateway: batch tool names, flush periodically _BATCH_SIZE = 5 - _batch: List[str] = [] + _batch: list[str] = [] def _callback(tool_name: str, preview: str = None): # Special "_thinking" event: model produced text content (reasoning) @@ -106,7 +110,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in if spinner: short = (preview[:55] + "...") if preview and len(preview) > 55 else (preview or "") try: - spinner.print_above(f" {prefix}├─ 💭 \"{short}\"") + spinner.print_above(f' {prefix}├─ 💭 "{short}"') except Exception: pass # Don't relay thinking to gateway (too noisy for chat) @@ -116,17 +120,25 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in if spinner: short = (preview[:35] + "...") if preview and len(preview) > 35 else (preview or "") tool_emojis = { - "terminal": "💻", "web_search": "🔍", "web_extract": "📄", - "read_file": "📖", "write_file": "✍️", "patch": "🔧", - "search_files": "🔎", "list_directory": "📂", - "browser_navigate": "🌐", "browser_click": "👆", - "text_to_speech": "🔊", "image_generate": "🎨", - "vision_analyze": "👁️", "process": "⚙️", + "terminal": "💻", + "web_search": "🔍", + "web_extract": "📄", + "read_file": "📖", + "write_file": "✍️", + "patch": "🔧", + "search_files": "🔎", + "list_directory": "📂", + "browser_navigate": "🌐", + "browser_click": "👆", + "text_to_speech": "🔊", + "image_generate": "🎨", + "vision_analyze": "👁️", + "process": "⚙️", } emoji = tool_emojis.get(tool_name, "⚡") line = f" {prefix}├─ {emoji} {tool_name}" if short: - line += f" \"{short}\"" + line += f' "{short}"' try: spinner.print_above(line) except Exception: @@ -159,13 +171,13 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in def _run_single_child( task_index: int, goal: str, - context: Optional[str], - toolsets: Optional[List[str]], - model: Optional[str], + context: str | None, + toolsets: list[str] | None, + model: str | None, max_iterations: int, parent_agent, task_count: int = 1, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Spawn and run a single child agent. Called from within a thread. Returns a structured result dict. @@ -216,7 +228,7 @@ def _run_single_child( skip_context_files=True, skip_memory=True, clarify_callback=None, - session_db=getattr(parent_agent, '_session_db', None), + session_db=getattr(parent_agent, "_session_db", None), providers_allowed=parent_agent.providers_allowed, providers_ignored=parent_agent.providers_ignored, providers_order=parent_agent.providers_order, @@ -226,10 +238,10 @@ def _run_single_child( ) # Set delegation depth so children can't spawn grandchildren - child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1 + child._delegate_depth = getattr(parent_agent, "_delegate_depth", 0) + 1 # Register child for interrupt propagation - if hasattr(parent_agent, '_active_children'): + if hasattr(parent_agent, "_active_children"): parent_agent._active_children.append(child) # Run with stdout/stderr suppressed to prevent interleaved output @@ -238,7 +250,7 @@ def _run_single_child( result = child.run_conversation(user_message=goal) # Flush any remaining batched progress to gateway - if child_progress_cb and hasattr(child_progress_cb, '_flush'): + if child_progress_cb and hasattr(child_progress_cb, "_flush"): try: child_progress_cb._flush() except Exception: @@ -258,7 +270,7 @@ def _run_single_child( else: status = "failed" - entry: Dict[str, Any] = { + entry: dict[str, Any] = { "task_index": task_index, "status": status, "summary": summary, @@ -284,7 +296,7 @@ def _run_single_child( finally: # Unregister child from interrupt propagation - if hasattr(parent_agent, '_active_children'): + if hasattr(parent_agent, "_active_children"): try: parent_agent._active_children.remove(child) except (ValueError, UnboundLocalError): @@ -292,11 +304,11 @@ def _run_single_child( def delegate_task( - goal: Optional[str] = None, - context: Optional[str] = None, - toolsets: Optional[List[str]] = None, - tasks: Optional[List[Dict[str, Any]]] = None, - max_iterations: Optional[int] = None, + goal: str | None = None, + context: str | None = None, + toolsets: list[str] | None = None, + tasks: list[dict[str, Any]] | None = None, + max_iterations: int | None = None, parent_agent=None, ) -> str: """ @@ -312,14 +324,11 @@ def delegate_task( return json.dumps({"error": "delegate_task requires a parent agent context."}) # Depth limit - depth = getattr(parent_agent, '_delegate_depth', 0) + depth = getattr(parent_agent, "_delegate_depth", 0) if depth >= MAX_DEPTH: - return json.dumps({ - "error": ( - f"Delegation depth limit reached ({MAX_DEPTH}). " - "Subagents cannot spawn further subagents." - ) - }) + return json.dumps( + {"error": (f"Delegation depth limit reached ({MAX_DEPTH}). Subagents cannot spawn further subagents.")} + ) # Load config cfg = _load_config() @@ -366,7 +375,7 @@ def delegate_task( else: # Batch -- run in parallel with per-task progress lines completed_count = 0 - spinner_ref = getattr(parent_agent, '_delegate_spinner', None) + spinner_ref = getattr(parent_agent, "_delegate_spinner", None) # Save stdout/stderr before the executor — redirect_stdout in child # threads races on sys.stdout and can leave it as devnull permanently. @@ -412,7 +421,7 @@ def delegate_task( status = entry.get("status", "?") icon = "✓" if status == "completed" else "✗" remaining = n_tasks - completed_count - completion_line = f"{icon} [{idx+1}/{n_tasks}] {label} ({dur}s)" + completion_line = f"{icon} [{idx + 1}/{n_tasks}] {label} ({dur}s)" if spinner_ref: try: spinner_ref.print_above(completion_line) @@ -437,16 +446,20 @@ def delegate_task( total_duration = round(time.monotonic() - overall_start, 2) - return json.dumps({ - "results": results, - "total_duration_seconds": total_duration, - }, ensure_ascii=False) + return json.dumps( + { + "results": results, + "total_duration_seconds": total_duration, + }, + ensure_ascii=False, + ) def _load_config() -> dict: """Load delegation config from CLI_CONFIG if available.""" try: from cli import CLI_CONFIG + return CLI_CONFIG.get("delegation", {}) except Exception: return {} @@ -537,10 +550,7 @@ DELEGATE_TASK_SCHEMA = { }, "max_iterations": { "type": "integer", - "description": ( - "Max tool-calling turns per subagent (default: 50). " - "Only set lower for simple tasks." - ), + "description": ("Max tool-calling turns per subagent (default: 50). Only set lower for simple tasks."), }, }, "required": [], @@ -561,6 +571,7 @@ registry.register( toolsets=args.get("toolsets"), tasks=args.get("tasks"), max_iterations=args.get("max_iterations"), - parent_agent=kw.get("parent_agent")), + parent_agent=kw.get("parent_agent"), + ), check_fn=check_delegate_requirements, ) diff --git a/tools/environments/base.py b/tools/environments/base.py index 50bf3b2adc..43fd283b75 100644 --- a/tools/environments/base.py +++ b/tools/environments/base.py @@ -1,8 +1,8 @@ """Base class for all Hermes execution environment backends.""" -from abc import ABC, abstractmethod import os import subprocess +from abc import ABC, abstractmethod from pathlib import Path @@ -34,9 +34,9 @@ class BaseEnvironment(ABC): self.env = env or {} @abstractmethod - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: + def execute( + self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None + ) -> dict: """Execute a command, return {"output": str, "returncode": int}.""" ... @@ -62,10 +62,10 @@ class BaseEnvironment(ABC): def _prepare_command(self, command: str) -> str: """Transform sudo commands if SUDO_PASSWORD is available.""" from tools.terminal_tool import _transform_sudo_command + return _transform_sudo_command(command) - def _build_run_kwargs(self, timeout: int | None, - stdin_data: str | None = None) -> dict: + def _build_run_kwargs(self, timeout: int | None, stdin_data: str | None = None) -> dict: """Build common subprocess.run kwargs for non-interactive execution.""" kw = { "text": True, diff --git a/tools/environments/daytona.py b/tools/environments/daytona.py index c8df198c1c..333d5563cd 100644 --- a/tools/environments/daytona.py +++ b/tools/environments/daytona.py @@ -11,7 +11,6 @@ import shlex import threading import uuid import warnings -from typing import Optional from tools.environments.base import BaseEnvironment from tools.interrupt import is_interrupted @@ -32,8 +31,8 @@ class DaytonaEnvironment(BaseEnvironment): cwd: str = "/home/daytona", timeout: int = 60, cpu: int = 1, - memory: int = 5120, # MB (hermes convention) - disk: int = 10240, # MB (Daytona platform max is 10GB) + memory: int = 5120, # MB (hermes convention) + disk: int = 10240, # MB (Daytona platform max is 10GB) persistent_filesystem: bool = True, task_id: str = "default", ): @@ -41,8 +40,8 @@ class DaytonaEnvironment(BaseEnvironment): super().__init__(cwd=cwd, timeout=timeout) from daytona import ( - Daytona, CreateSandboxFromImageParams, + Daytona, DaytonaError, Resources, SandboxState, @@ -73,13 +72,11 @@ class DaytonaEnvironment(BaseEnvironment): try: self._sandbox = self._daytona.find_one(labels=labels) self._sandbox.start() - logger.info("Daytona: resumed sandbox %s for task %s", - self._sandbox.id, task_id) + logger.info("Daytona: resumed sandbox %s for task %s", self._sandbox.id, task_id) except DaytonaError: self._sandbox = None except Exception as e: - logger.warning("Daytona: failed to resume sandbox for task %s: %s", - task_id, e) + logger.warning("Daytona: failed to resume sandbox for task %s: %s", task_id, e) self._sandbox = None # Create a fresh sandbox if we don't have one @@ -92,8 +89,7 @@ class DaytonaEnvironment(BaseEnvironment): resources=resources, ) ) - logger.info("Daytona: created sandbox %s for task %s", - self._sandbox.id, task_id) + logger.info("Daytona: created sandbox %s for task %s", self._sandbox.id, task_id) # Resolve cwd: detect actual home dir inside the sandbox if self._requested_cwd in ("~", "/home/daytona"): @@ -112,7 +108,7 @@ class DaytonaEnvironment(BaseEnvironment): self._sandbox.start() logger.info("Daytona: restarted sandbox %s", self._sandbox.id) - def _exec_in_thread(self, exec_command: str, cwd: Optional[str], timeout: int) -> dict: + def _exec_in_thread(self, exec_command: str, cwd: str | None, timeout: int) -> dict: """Run exec in a background thread with interrupt polling. The Daytona SDK's exec(timeout=...) parameter is unreliable (the @@ -130,7 +126,8 @@ class DaytonaEnvironment(BaseEnvironment): def _run(): try: response = self._sandbox.process.exec( - timed_command, cwd=cwd, + timed_command, + cwd=cwd, ) result_holder["value"] = { "output": response.result or "", @@ -169,9 +166,9 @@ class DaytonaEnvironment(BaseEnvironment): return {"error": result_holder["error"]} return result_holder["value"] - def execute(self, command: str, cwd: str = "", *, - timeout: Optional[int] = None, - stdin_data: Optional[str] = None) -> dict: + def execute( + self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None + ) -> dict: with self._lock: self._ensure_sandbox_ready() @@ -189,6 +186,7 @@ class DaytonaEnvironment(BaseEnvironment): if "error" in result: from daytona import DaytonaError + err = result["error"] if isinstance(err, DaytonaError): with self._lock: @@ -210,8 +208,7 @@ class DaytonaEnvironment(BaseEnvironment): try: if self._persistent: self._sandbox.stop() - logger.info("Daytona: stopped sandbox %s (filesystem preserved)", - self._sandbox.id) + logger.info("Daytona: stopped sandbox %s (filesystem preserved)", self._sandbox.id) else: self._daytona.delete(self._sandbox) logger.info("Daytona: deleted sandbox %s", self._sandbox.id) diff --git a/tools/environments/docker.py b/tools/environments/docker.py index faf01b2a25..b3cd35cac4 100644 --- a/tools/environments/docker.py +++ b/tools/environments/docker.py @@ -11,7 +11,6 @@ import subprocess import sys import threading import time -from typing import Optional from tools.environments.base import BaseEnvironment from tools.interrupt import is_interrupted @@ -19,7 +18,6 @@ from tools.interrupt import is_interrupted logger = logging.getLogger(__name__) - # Security flags applied to every container. # The container itself is the security boundary (isolated from host). # We drop all capabilities then add back the minimum needed: @@ -28,19 +26,28 @@ logger = logging.getLogger(__name__) # Block privilege escalation and limit PIDs. # /tmp is size-limited and nosuid but allows exec (needed by pip/npm builds). _SECURITY_ARGS = [ - "--cap-drop", "ALL", - "--cap-add", "DAC_OVERRIDE", - "--cap-add", "CHOWN", - "--cap-add", "FOWNER", - "--security-opt", "no-new-privileges", - "--pids-limit", "256", - "--tmpfs", "/tmp:rw,nosuid,size=512m", - "--tmpfs", "/var/tmp:rw,noexec,nosuid,size=256m", - "--tmpfs", "/run:rw,noexec,nosuid,size=64m", + "--cap-drop", + "ALL", + "--cap-add", + "DAC_OVERRIDE", + "--cap-add", + "CHOWN", + "--cap-add", + "FOWNER", + "--security-opt", + "no-new-privileges", + "--pids-limit", + "256", + "--tmpfs", + "/tmp:rw,nosuid,size=512m", + "--tmpfs", + "/var/tmp:rw,noexec,nosuid,size=256m", + "--tmpfs", + "/run:rw,noexec,nosuid,size=64m", ] -_storage_opt_ok: Optional[bool] = None # cached result across instances +_storage_opt_ok: bool | None = None # cached result across instances class DockerEnvironment(BaseEnvironment): @@ -74,7 +81,7 @@ class DockerEnvironment(BaseEnvironment): self._base_image = image self._persistent = persistent_filesystem self._task_id = task_id - self._container_id: Optional[str] = None + self._container_id: str | None = None logger.info(f"DockerEnvironment volumes: {volumes}") # Ensure volumes is a list (config.yaml could be malformed) if volumes is not None and not isinstance(volumes, list): @@ -105,8 +112,8 @@ class DockerEnvironment(BaseEnvironment): # mode uses tmpfs (ephemeral, fast, gone on cleanup). from tools.environments.base import get_sandbox_dir - self._workspace_dir: Optional[str] = None - self._home_dir: Optional[str] = None + self._workspace_dir: str | None = None + self._home_dir: str | None = None if self._persistent: sandbox = get_sandbox_dir() / "docker" / task_id self._workspace_dir = str(sandbox / "workspace") @@ -114,14 +121,19 @@ class DockerEnvironment(BaseEnvironment): os.makedirs(self._workspace_dir, exist_ok=True) os.makedirs(self._home_dir, exist_ok=True) writable_args = [ - "-v", f"{self._workspace_dir}:/workspace", - "-v", f"{self._home_dir}:/root", + "-v", + f"{self._workspace_dir}:/workspace", + "-v", + f"{self._home_dir}:/root", ] else: writable_args = [ - "--tmpfs", "/workspace:rw,exec,size=10g", - "--tmpfs", "/home:rw,exec,size=1g", - "--tmpfs", "/root:rw,exec,size=1g", + "--tmpfs", + "/workspace:rw,exec,size=10g", + "--tmpfs", + "/home:rw,exec,size=1g", + "--tmpfs", + "/root:rw,exec,size=1g", ] # All containers get security hardening (capabilities dropped, no privilege @@ -129,7 +141,7 @@ class DockerEnvironment(BaseEnvironment): # can install packages as needed. # User-configured volume mounts (from config.yaml docker_volumes) volume_args = [] - for vol in (volumes or []): + for vol in volumes or []: if not isinstance(vol, str): logger.warning(f"Docker volume entry is not a string: {vol!r}") continue @@ -146,7 +158,9 @@ class DockerEnvironment(BaseEnvironment): logger.info(f"Docker run_args: {all_run_args}") self._inner = _Docker( - image=image, cwd=cwd, timeout=timeout, + image=image, + cwd=cwd, + timeout=timeout, run_args=all_run_args, ) self._container_id = self._inner.container_id @@ -154,7 +168,7 @@ class DockerEnvironment(BaseEnvironment): @staticmethod def _storage_opt_supported() -> bool: """Check if Docker's storage driver supports --storage-opt size=. - + Only overlay2 on XFS with pquota supports per-container disk quotas. Ubuntu (and most distros) default to ext4, where this flag errors out. """ @@ -164,7 +178,9 @@ class DockerEnvironment(BaseEnvironment): try: result = subprocess.run( ["docker", "info", "--format", "{{.Driver}}"], - capture_output=True, text=True, timeout=10, + capture_output=True, + text=True, + timeout=10, ) driver = result.stdout.strip().lower() if driver != "overlay2": @@ -174,14 +190,15 @@ class DockerEnvironment(BaseEnvironment): # Probe by attempting a dry-ish run — the fastest reliable check. probe = subprocess.run( ["docker", "create", "--storage-opt", "size=1m", "hello-world"], - capture_output=True, text=True, timeout=15, + capture_output=True, + text=True, + timeout=15, ) if probe.returncode == 0: # Clean up the created container container_id = probe.stdout.strip() if container_id: - subprocess.run(["docker", "rm", container_id], - capture_output=True, timeout=5) + subprocess.run(["docker", "rm", container_id], capture_output=True, timeout=5) _storage_opt_ok = True else: _storage_opt_ok = False @@ -190,9 +207,9 @@ class DockerEnvironment(BaseEnvironment): logger.debug("Docker --storage-opt support: %s", _storage_opt_ok) return _storage_opt_ok - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: + def execute( + self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None + ) -> dict: exec_command = self._prepare_command(command) work_dir = cwd or self.cwd effective_timeout = timeout or self.timeout @@ -218,7 +235,8 @@ class DockerEnvironment(BaseEnvironment): _output_chunks = [] proc = subprocess.Popen( cmd, - stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, stdin=subprocess.PIPE if stdin_data else subprocess.DEVNULL, text=True, ) @@ -269,6 +287,7 @@ class DockerEnvironment(BaseEnvironment): if not self._persistent: import shutil + for d in (self._workspace_dir, self._home_dir): if d: shutil.rmtree(d, ignore_errors=True) diff --git a/tools/environments/local.py b/tools/environments/local.py index e1df97b4cd..4a700f0d3c 100644 --- a/tools/environments/local.py +++ b/tools/environments/local.py @@ -154,9 +154,9 @@ class LocalEnvironment(BaseEnvironment): def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None): super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env) - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: + def execute( + self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None + ) -> dict: from tools.terminal_tool import _interrupt_event work_dir = cwd or self.cwd or os.getcwd() @@ -172,11 +172,7 @@ class LocalEnvironment(BaseEnvironment): # Wrap with output fences so we can later extract the real # command output and discard shell init/exit noise. fenced_cmd = ( - f"printf '{_OUTPUT_FENCE}';" - f" {exec_command};" - f" __hermes_rc=$?;" - f" printf '{_OUTPUT_FENCE}';" - f" exit $__hermes_rc" + f"printf '{_OUTPUT_FENCE}'; {exec_command}; __hermes_rc=$?; printf '{_OUTPUT_FENCE}'; exit $__hermes_rc" ) # Ensure PATH always includes standard dirs — systemd services # and some terminal multiplexers inherit a minimal PATH. @@ -200,12 +196,14 @@ class LocalEnvironment(BaseEnvironment): ) if stdin_data is not None: + def _write_stdin(): try: proc.stdin.write(stdin_data) proc.stdin.close() except (BrokenPipeError, OSError): pass + threading.Thread(target=_write_stdin, daemon=True).start() _output_chunks: list[str] = [] diff --git a/tools/environments/modal.py b/tools/environments/modal.py index 84a9a6d75b..3e0af81d09 100644 --- a/tools/environments/modal.py +++ b/tools/environments/modal.py @@ -8,10 +8,9 @@ project files, and config changes survive across sessions. import json import logging import threading -import time import uuid from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any from tools.environments.base import BaseEnvironment from tools.interrupt import is_interrupted @@ -21,7 +20,7 @@ logger = logging.getLogger(__name__) _SNAPSHOT_STORE = Path.home() / ".hermes" / "modal_snapshots.json" -def _load_snapshots() -> Dict[str, str]: +def _load_snapshots() -> dict[str, str]: """Load snapshot ID mapping from disk.""" if _SNAPSHOT_STORE.exists(): try: @@ -31,7 +30,7 @@ def _load_snapshots() -> Dict[str, str]: return {} -def _save_snapshots(data: Dict[str, str]) -> None: +def _save_snapshots(data: dict[str, str]) -> None: """Persist snapshot ID mapping to disk.""" _SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True) _SNAPSHOT_STORE.write_text(json.dumps(data, indent=2)) @@ -52,7 +51,7 @@ class ModalEnvironment(BaseEnvironment): image: str, cwd: str = "~", timeout: int = 60, - modal_sandbox_kwargs: Optional[Dict[str, Any]] = None, + modal_sandbox_kwargs: dict[str, Any] | None = None, persistent_filesystem: bool = True, task_id: str = "default", ): @@ -61,6 +60,7 @@ class ModalEnvironment(BaseEnvironment): if not ModalEnvironment._patches_applied: try: from environments.patches import apply_patches + apply_patches() except ImportError: pass @@ -79,6 +79,7 @@ class ModalEnvironment(BaseEnvironment): if snapshot_id: try: import modal + restored_image = modal.Image.from_id(snapshot_id) logger.info("Modal: restoring from snapshot %s", snapshot_id[:20]) except Exception as e: @@ -88,6 +89,7 @@ class ModalEnvironment(BaseEnvironment): effective_image = restored_image if restored_image else image from minisweagent.environments.extra.swerex_modal import SwerexModalEnvironment + self._inner = SwerexModalEnvironment( image=effective_image, cwd=cwd, @@ -97,9 +99,9 @@ class ModalEnvironment(BaseEnvironment): modal_sandbox_kwargs=sandbox_kwargs, ) - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: + def execute( + self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None + ) -> dict: if stdin_data is not None: marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}" while marker in stdin_data: @@ -139,29 +141,29 @@ class ModalEnvironment(BaseEnvironment): """Snapshot the filesystem (if persistent) then stop the sandbox.""" if self._persistent: try: - sandbox = getattr(self._inner, 'deployment', None) - sandbox = getattr(sandbox, '_sandbox', None) if sandbox else None + sandbox = getattr(self._inner, "deployment", None) + sandbox = getattr(sandbox, "_sandbox", None) if sandbox else None if sandbox: import asyncio + async def _snapshot(): img = await sandbox.snapshot_filesystem.aio() return img.object_id + try: snapshot_id = asyncio.run(_snapshot()) except RuntimeError: import concurrent.futures + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: - snapshot_id = pool.submit( - asyncio.run, _snapshot() - ).result(timeout=60) + snapshot_id = pool.submit(asyncio.run, _snapshot()).result(timeout=60) snapshots = _load_snapshots() snapshots[self._task_id] = snapshot_id _save_snapshots(snapshots) - logger.info("Modal: saved filesystem snapshot %s for task %s", - snapshot_id[:20], self._task_id) + logger.info("Modal: saved filesystem snapshot %s for task %s", snapshot_id[:20], self._task_id) except Exception as e: logger.warning("Modal: filesystem snapshot failed: %s", e) - if hasattr(self._inner, 'stop'): + if hasattr(self._inner, "stop"): self._inner.stop() diff --git a/tools/environments/singularity.py b/tools/environments/singularity.py index c5d10e9dbb..3589f5e4a3 100644 --- a/tools/environments/singularity.py +++ b/tools/environments/singularity.py @@ -10,11 +10,9 @@ import logging import os import shutil import subprocess -import tempfile import threading import uuid from pathlib import Path -from typing import Any, Dict, Optional from tools.environments.base import BaseEnvironment from tools.interrupt import is_interrupted @@ -24,7 +22,7 @@ logger = logging.getLogger(__name__) _SNAPSHOT_STORE = Path.home() / ".hermes" / "singularity_snapshots.json" -def _load_snapshots() -> Dict[str, str]: +def _load_snapshots() -> dict[str, str]: if _SNAPSHOT_STORE.exists(): try: return json.loads(_SNAPSHOT_STORE.read_text()) @@ -33,7 +31,7 @@ def _load_snapshots() -> Dict[str, str]: return {} -def _save_snapshots(data: Dict[str, str]) -> None: +def _save_snapshots(data: dict[str, str]) -> None: _SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True) _SNAPSHOT_STORE.write_text(json.dumps(data, indent=2)) @@ -42,6 +40,7 @@ def _save_snapshots(data: Dict[str, str]) -> None: # Singularity helpers (scratch dir, SIF cache, SIF building) # ------------------------------------------------------------------------- + def _get_scratch_dir() -> Path: """Get the best directory for Singularity sandboxes. @@ -58,6 +57,7 @@ def _get_scratch_dir() -> Path: return scratch_path from tools.environments.base import get_sandbox_dir + sandbox = get_sandbox_dir() / "singularity" scratch = Path("/scratch") @@ -93,12 +93,12 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str: Returns the path unchanged if it's already a .sif file. For docker:// URLs, checks the cache and builds if needed. """ - if image.endswith('.sif') and Path(image).exists(): + if image.endswith(".sif") and Path(image).exists(): return image - if not image.startswith('docker://'): + if not image.startswith("docker://"): return image - image_name = image.replace('docker://', '').replace('/', '-').replace(':', '-') + image_name = image.replace("docker://", "").replace("/", "-").replace(":", "-") cache_dir = _get_apptainer_cache_dir() sif_path = cache_dir / f"{image_name}.sif" @@ -123,7 +123,10 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str: try: result = subprocess.run( [executable, "build", str(sif_path), image], - capture_output=True, text=True, timeout=600, env=env, + capture_output=True, + text=True, + timeout=600, + env=env, ) if result.returncode != 0: logger.warning("SIF build failed, falling back to docker:// URL") @@ -145,6 +148,7 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str: # SingularityEnvironment # ------------------------------------------------------------------------- + class SingularityEnvironment(BaseEnvironment): """Hardened Singularity/Apptainer container with resource limits and persistence. @@ -174,7 +178,7 @@ class SingularityEnvironment(BaseEnvironment): self._instance_started = False self._persistent = persistent_filesystem self._task_id = task_id - self._overlay_dir: Optional[Path] = None + self._overlay_dir: Path | None = None # Resource limits self._cpu = cpu @@ -215,14 +219,13 @@ class SingularityEnvironment(BaseEnvironment): if result.returncode != 0: raise RuntimeError(f"Failed to start instance: {result.stderr}") self._instance_started = True - logger.info("Singularity instance %s started (persistent=%s)", - self.instance_id, self._persistent) + logger.info("Singularity instance %s started (persistent=%s)", self.instance_id, self._persistent) except subprocess.TimeoutExpired: raise RuntimeError("Instance start timed out") - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: + def execute( + self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None + ) -> dict: if not self._instance_started: return {"output": "Instance not started", "returncode": -1} @@ -235,16 +238,16 @@ class SingularityEnvironment(BaseEnvironment): exec_command = f"cd {work_dir} && {exec_command}" work_dir = "/tmp" - cmd = [self.executable, "exec", "--pwd", work_dir, - f"instance://{self.instance_id}", - "bash", "-c", exec_command] + cmd = [self.executable, "exec", "--pwd", work_dir, f"instance://{self.instance_id}", "bash", "-c", exec_command] try: import time as _time + _output_chunks = [] proc = subprocess.Popen( cmd, - stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, stdin=subprocess.PIPE if stdin_data else subprocess.DEVNULL, text=True, ) @@ -295,7 +298,9 @@ class SingularityEnvironment(BaseEnvironment): try: subprocess.run( [self.executable, "instance", "stop", self.instance_id], - capture_output=True, text=True, timeout=30, + capture_output=True, + text=True, + timeout=30, ) logger.info("Singularity instance %s stopped", self.instance_id) except Exception as e: diff --git a/tools/environments/ssh.py b/tools/environments/ssh.py index 02acce244c..c4216beef3 100644 --- a/tools/environments/ssh.py +++ b/tools/environments/ssh.py @@ -24,8 +24,7 @@ class SSHEnvironment(BaseEnvironment): and a remote kill is attempted over the ControlMaster socket. """ - def __init__(self, host: str, user: str, cwd: str = "~", - timeout: int = 60, port: int = 22, key_path: str = ""): + def __init__(self, host: str, user: str, cwd: str = "~", timeout: int = 60, port: int = 22, key_path: str = ""): super().__init__(cwd=cwd, timeout=timeout) self.host = host self.user = user @@ -65,12 +64,12 @@ class SSHEnvironment(BaseEnvironment): except subprocess.TimeoutExpired: raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out") - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: + def execute( + self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None + ) -> dict: work_dir = cwd or self.cwd exec_command = self._prepare_command(command) - wrapped = f'cd {work_dir} && {exec_command}' + wrapped = f"cd {work_dir} && {exec_command}" effective_timeout = timeout or self.timeout cmd = self._build_ssh_command() @@ -136,8 +135,7 @@ class SSHEnvironment(BaseEnvironment): def cleanup(self): if self.control_socket.exists(): try: - cmd = ["ssh", "-o", f"ControlPath={self.control_socket}", - "-O", "exit", f"{self.user}@{self.host}"] + cmd = ["ssh", "-o", f"ControlPath={self.control_socket}", "-O", "exit", f"{self.user}@{self.host}"] subprocess.run(cmd, capture_output=True, timeout=5) except (OSError, subprocess.SubprocessError): pass diff --git a/tools/file_operations.py b/tools/file_operations.py index b3b8f15309..f845303e67 100644 --- a/tools/file_operations.py +++ b/tools/file_operations.py @@ -11,29 +11,27 @@ so we wrap the terminal backend's execute() interface to provide a unified file Usage: from tools.file_operations import ShellFileOperations from tools.terminal_tool import _active_environments - + # Get file operations for a terminal environment file_ops = ShellFileOperations(terminal_env) - + # Read a file result = file_ops.read_file("/path/to/file.py") - + # Write a file result = file_ops.write_file("/path/to/new.py", "print('hello')") - + # Search for content result = file_ops.search("TODO", path=".", file_glob="*.py") """ +import difflib import os import re -import json -import difflib from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Optional, List, Dict, Any, Tuple from pathlib import Path - +from typing import Any # --------------------------------------------------------------------------- # Write-path deny list — blocks writes to sensitive system/credential files @@ -42,7 +40,8 @@ from pathlib import Path _HOME = str(Path.home()) WRITE_DENIED_PATHS = { - os.path.realpath(p) for p in [ + os.path.realpath(p) + for p in [ os.path.join(_HOME, ".ssh", "authorized_keys"), os.path.join(_HOME, ".ssh", "id_rsa"), os.path.join(_HOME, ".ssh", "id_ed25519"), @@ -64,7 +63,8 @@ WRITE_DENIED_PATHS = { } WRITE_DENIED_PREFIXES = [ - os.path.realpath(p) + os.sep for p in [ + os.path.realpath(p) + os.sep + for p in [ os.path.join(_HOME, ".ssh"), os.path.join(_HOME, ".aws"), os.path.join(_HOME, ".gnupg"), @@ -90,22 +90,24 @@ def _is_write_denied(path: str) -> bool: # Result Data Classes # ============================================================================= + @dataclass class ReadResult: """Result from reading a file.""" + content: str = "" total_lines: int = 0 file_size: int = 0 truncated: bool = False - hint: Optional[str] = None + hint: str | None = None is_binary: bool = False is_image: bool = False - base64_content: Optional[str] = None - mime_type: Optional[str] = None - dimensions: Optional[str] = None # For images: "WIDTHxHEIGHT" - error: Optional[str] = None - similar_files: List[str] = field(default_factory=list) - + base64_content: str | None = None + mime_type: str | None = None + dimensions: str | None = None # For images: "WIDTHxHEIGHT" + error: str | None = None + similar_files: list[str] = field(default_factory=list) + def to_dict(self) -> dict: return {k: v for k, v in self.__dict__.items() if v is not None and v != []} @@ -113,11 +115,12 @@ class ReadResult: @dataclass class WriteResult: """Result from writing a file.""" + bytes_written: int = 0 dirs_created: bool = False - error: Optional[str] = None - warning: Optional[str] = None - + error: str | None = None + warning: str | None = None + def to_dict(self) -> dict: return {k: v for k, v in self.__dict__.items() if v is not None} @@ -125,14 +128,15 @@ class WriteResult: @dataclass class PatchResult: """Result from patching a file.""" + success: bool = False diff: str = "" - files_modified: List[str] = field(default_factory=list) - files_created: List[str] = field(default_factory=list) - files_deleted: List[str] = field(default_factory=list) - lint: Optional[Dict[str, Any]] = None - error: Optional[str] = None - + files_modified: list[str] = field(default_factory=list) + files_created: list[str] = field(default_factory=list) + files_deleted: list[str] = field(default_factory=list) + lint: dict[str, Any] | None = None + error: str | None = None + def to_dict(self) -> dict: result = {"success": self.success} if self.diff: @@ -153,6 +157,7 @@ class PatchResult: @dataclass class SearchMatch: """A single search match.""" + path: str line_number: int content: str @@ -162,20 +167,18 @@ class SearchMatch: @dataclass class SearchResult: """Result from searching.""" - matches: List[SearchMatch] = field(default_factory=list) - files: List[str] = field(default_factory=list) - counts: Dict[str, int] = field(default_factory=dict) + + matches: list[SearchMatch] = field(default_factory=list) + files: list[str] = field(default_factory=list) + counts: dict[str, int] = field(default_factory=dict) total_count: int = 0 truncated: bool = False - error: Optional[str] = None - + error: str | None = None + def to_dict(self) -> dict: result = {"total_count": self.total_count} if self.matches: - result["matches"] = [ - {"path": m.path, "line": m.line_number, "content": m.content} - for m in self.matches - ] + result["matches"] = [{"path": m.path, "line": m.line_number, "content": m.content} for m in self.matches] if self.files: result["files"] = self.files if self.counts: @@ -190,23 +193,22 @@ class SearchResult: @dataclass class LintResult: """Result from linting a file.""" + success: bool = True skipped: bool = False output: str = "" message: str = "" - + def to_dict(self) -> dict: if self.skipped: return {"status": "skipped", "message": self.message} - return { - "status": "ok" if self.success else "error", - "output": self.output - } + return {"status": "ok" if self.success else "error", "output": self.output} @dataclass class ExecuteResult: """Result from executing a shell command.""" + stdout: str = "" exit_code: int = 0 @@ -215,34 +217,42 @@ class ExecuteResult: # Abstract Interface # ============================================================================= + class FileOperations(ABC): """Abstract interface for file operations across terminal backends.""" - + @abstractmethod def read_file(self, path: str, offset: int = 1, limit: int = 500) -> ReadResult: """Read a file with pagination support.""" ... - + @abstractmethod def write_file(self, path: str, content: str) -> WriteResult: """Write content to a file, creating directories as needed.""" ... - + @abstractmethod - def patch_replace(self, path: str, old_string: str, new_string: str, - replace_all: bool = False) -> PatchResult: + def patch_replace(self, path: str, old_string: str, new_string: str, replace_all: bool = False) -> PatchResult: """Replace text in a file using fuzzy matching.""" ... - + @abstractmethod def patch_v4a(self, patch_content: str) -> PatchResult: """Apply a V4A format patch.""" ... - + @abstractmethod - def search(self, pattern: str, path: str = ".", target: str = "content", - file_glob: Optional[str] = None, limit: int = 50, offset: int = 0, - output_mode: str = "content", context: int = 0) -> SearchResult: + def search( + self, + pattern: str, + path: str = ".", + target: str = "content", + file_glob: str | None = None, + limit: int = 50, + offset: int = 0, + output_mode: str = "content", + context: int = 0, + ) -> SearchResult: """Search for content or files.""" ... @@ -254,33 +264,76 @@ class FileOperations(ABC): # Binary file extensions (fast path check) BINARY_EXTENSIONS = { # Images - '.png', '.jpg', '.jpeg', '.gif', '.webp', '.bmp', '.ico', '.tiff', '.tif', - '.svg', # SVG is text but often treated as binary + ".png", + ".jpg", + ".jpeg", + ".gif", + ".webp", + ".bmp", + ".ico", + ".tiff", + ".tif", + ".svg", # SVG is text but often treated as binary # Audio/Video - '.mp3', '.mp4', '.wav', '.avi', '.mov', '.mkv', '.flac', '.ogg', '.webm', + ".mp3", + ".mp4", + ".wav", + ".avi", + ".mov", + ".mkv", + ".flac", + ".ogg", + ".webm", # Archives - '.zip', '.tar', '.gz', '.bz2', '.xz', '.7z', '.rar', + ".zip", + ".tar", + ".gz", + ".bz2", + ".xz", + ".7z", + ".rar", # Documents - '.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', + ".pdf", + ".doc", + ".docx", + ".xls", + ".xlsx", + ".ppt", + ".pptx", # Compiled/Binary - '.exe', '.dll', '.so', '.dylib', '.o', '.a', '.pyc', '.pyo', '.class', - '.wasm', '.bin', + ".exe", + ".dll", + ".so", + ".dylib", + ".o", + ".a", + ".pyc", + ".pyo", + ".class", + ".wasm", + ".bin", # Fonts - '.ttf', '.otf', '.woff', '.woff2', '.eot', + ".ttf", + ".otf", + ".woff", + ".woff2", + ".eot", # Other - '.db', '.sqlite', '.sqlite3', + ".db", + ".sqlite", + ".sqlite3", } # Image extensions (subset of binary that we can return as base64) -IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.webp', '.bmp', '.ico'} +IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".ico"} # Linters by file extension LINTERS = { - '.py': 'python -m py_compile {file} 2>&1', - '.js': 'node --check {file} 2>&1', - '.ts': 'npx tsc --noEmit {file} 2>&1', - '.go': 'go vet {file} 2>&1', - '.rs': 'rustfmt --check {file} 2>&1', + ".py": "python -m py_compile {file} 2>&1", + ".js": "node --check {file} 2>&1", + ".ts": "npx tsc --noEmit {file} 2>&1", + ".go": "go vet {file} 2>&1", + ".rs": "rustfmt --check {file} 2>&1", } # Max limits for read operations @@ -292,15 +345,15 @@ MAX_FILE_SIZE = 50 * 1024 # 50KB class ShellFileOperations(FileOperations): """ File operations implemented via shell commands. - + Works with ANY terminal backend that has execute(command, cwd) method. This includes local, docker, singularity, ssh, modal, and daytona environments. """ - + def __init__(self, terminal_env, cwd: str = None): """ Initialize file operations with a terminal environment. - + Args: terminal_env: Any object with execute(command, cwd) method. Returns {"output": str, "returncode": int} @@ -311,164 +364,159 @@ class ShellFileOperations(FileOperations): # IMPORTANT: do NOT fall back to os.getcwd() -- that's the HOST's local # path which doesn't exist inside container/cloud backends (modal, docker). # If nothing provides a cwd, use "/" as a safe universal default. - self.cwd = cwd or getattr(terminal_env, 'cwd', None) or \ - getattr(getattr(terminal_env, 'config', None), 'cwd', None) or "/" - + self.cwd = ( + cwd + or getattr(terminal_env, "cwd", None) + or getattr(getattr(terminal_env, "config", None), "cwd", None) + or "/" + ) + # Cache for command availability checks - self._command_cache: Dict[str, bool] = {} - - def _exec(self, command: str, cwd: str = None, timeout: int = None, - stdin_data: str = None) -> ExecuteResult: + self._command_cache: dict[str, bool] = {} + + def _exec(self, command: str, cwd: str = None, timeout: int = None, stdin_data: str = None) -> ExecuteResult: """Execute command via terminal backend. - + Args: stdin_data: If provided, piped to the process's stdin instead of embedding in the command string. Bypasses ARG_MAX. """ kwargs = {} if timeout: - kwargs['timeout'] = timeout + kwargs["timeout"] = timeout if stdin_data is not None: - kwargs['stdin_data'] = stdin_data - + kwargs["stdin_data"] = stdin_data + result = self.env.execute(command, cwd=cwd or self.cwd, **kwargs) - return ExecuteResult( - stdout=result.get("output", ""), - exit_code=result.get("returncode", 0) - ) - + return ExecuteResult(stdout=result.get("output", ""), exit_code=result.get("returncode", 0)) + def _has_command(self, cmd: str) -> bool: """Check if a command exists in the environment (cached).""" if cmd not in self._command_cache: result = self._exec(f"command -v {cmd} >/dev/null 2>&1 && echo 'yes'") - self._command_cache[cmd] = result.stdout.strip() == 'yes' + self._command_cache[cmd] = result.stdout.strip() == "yes" return self._command_cache[cmd] - + def _is_likely_binary(self, path: str, content_sample: str = None) -> bool: """ Check if a file is likely binary. - + Uses extension check (fast) + content analysis (fallback). """ ext = os.path.splitext(path)[1].lower() if ext in BINARY_EXTENSIONS: return True - + # Content analysis: >30% non-printable chars = binary if content_sample: if not content_sample: return False - non_printable = sum(1 for c in content_sample[:1000] - if ord(c) < 32 and c not in '\n\r\t') + non_printable = sum(1 for c in content_sample[:1000] if ord(c) < 32 and c not in "\n\r\t") return non_printable / min(len(content_sample), 1000) > 0.30 - + return False - + def _is_image(self, path: str) -> bool: """Check if file is an image we can return as base64.""" ext = os.path.splitext(path)[1].lower() return ext in IMAGE_EXTENSIONS - + def _add_line_numbers(self, content: str, start_line: int = 1) -> str: """Add line numbers to content in LINE_NUM|CONTENT format.""" - lines = content.split('\n') + lines = content.split("\n") numbered = [] for i, line in enumerate(lines, start=start_line): # Truncate long lines if len(line) > MAX_LINE_LENGTH: line = line[:MAX_LINE_LENGTH] + "... [truncated]" numbered.append(f"{i:6d}|{line}") - return '\n'.join(numbered) - + return "\n".join(numbered) + def _expand_path(self, path: str) -> str: """ Expand shell-style paths like ~ and ~user to absolute paths. - + This must be done BEFORE shell escaping, since ~ doesn't expand inside single quotes. """ if not path: return path - + # Handle ~ and ~user - if path.startswith('~'): + if path.startswith("~"): # Get home directory via the terminal environment result = self._exec("echo $HOME") if result.exit_code == 0 and result.stdout.strip(): home = result.stdout.strip() - if path == '~': + if path == "~": return home - elif path.startswith('~/'): + elif path.startswith("~/"): return home + path[1:] # Replace ~ with home # ~username format - extract and validate username before # letting shell expand it (prevent shell injection via # paths like "~; rm -rf /"). rest = path[1:] # strip leading ~ - slash_idx = rest.find('/') + slash_idx = rest.find("/") username = rest[:slash_idx] if slash_idx >= 0 else rest - if username and re.fullmatch(r'[a-zA-Z0-9._-]+', username): + if username and re.fullmatch(r"[a-zA-Z0-9._-]+", username): expand_result = self._exec(f"echo {path}") if expand_result.exit_code == 0 and expand_result.stdout.strip(): return expand_result.stdout.strip() - + return path - + def _escape_shell_arg(self, arg: str) -> str: """Escape a string for safe use in shell commands.""" # Use single quotes and escape any single quotes in the string return "'" + arg.replace("'", "'\"'\"'") + "'" - + def _unified_diff(self, old_content: str, new_content: str, filename: str) -> str: """Generate unified diff between old and new content.""" old_lines = old_content.splitlines(keepends=True) new_lines = new_content.splitlines(keepends=True) - diff = difflib.unified_diff( - old_lines, new_lines, - fromfile=f"a/{filename}", - tofile=f"b/{filename}" - ) - return ''.join(diff) - + diff = difflib.unified_diff(old_lines, new_lines, fromfile=f"a/{filename}", tofile=f"b/{filename}") + return "".join(diff) + # ========================================================================= # READ Implementation # ========================================================================= - + def read_file(self, path: str, offset: int = 1, limit: int = 500) -> ReadResult: """ Read a file with pagination, binary detection, and line numbers. - + Args: path: File path (absolute or relative to cwd) offset: Line number to start from (1-indexed, default 1) limit: Maximum lines to return (default 500, max 2000) - + Returns: ReadResult with content, metadata, or error info """ # Expand ~ and other shell paths path = self._expand_path(path) - + # Clamp limit limit = min(limit, MAX_LINES) - + # Check if file exists and get size (wc -c is POSIX, works on Linux + macOS) stat_cmd = f"wc -c < {self._escape_shell_arg(path)} 2>/dev/null" stat_result = self._exec(stat_cmd) - + if stat_result.exit_code != 0: # File not found - try to suggest similar files return self._suggest_similar_files(path) - + try: file_size = int(stat_result.stdout.strip()) except ValueError: file_size = 0 - + # Check if file is too large if file_size > MAX_FILE_SIZE: # Still try to read, but warn pass - + # Images are never inlined — redirect to the vision tool if self._is_image(path): return ReadResult( @@ -480,26 +528,26 @@ class ShellFileOperations(FileOperations): "Use vision_analyze with this file path to inspect the image contents." ), ) - + # Read a sample to check for binary content sample_cmd = f"head -c 1000 {self._escape_shell_arg(path)} 2>/dev/null" sample_result = self._exec(sample_cmd) - + if self._is_likely_binary(path, sample_result.stdout): return ReadResult( is_binary=True, file_size=file_size, - error="Binary file - cannot display as text. Use appropriate tools to handle this file type." + error="Binary file - cannot display as text. Use appropriate tools to handle this file type.", ) - + # Read with pagination using sed end_line = offset + limit - 1 read_cmd = f"sed -n '{offset},{end_line}p' {self._escape_shell_arg(path)}" read_result = self._exec(read_cmd) - + if read_result.exit_code != 0: return ReadResult(error=f"Failed to read file: {read_result.stdout}") - + # Get total line count wc_cmd = f"wc -l < {self._escape_shell_arg(path)}" wc_result = self._exec(wc_cmd) @@ -507,21 +555,21 @@ class ShellFileOperations(FileOperations): total_lines = int(wc_result.stdout.strip()) except ValueError: total_lines = 0 - + # Check if truncated truncated = total_lines > end_line hint = None if truncated: hint = f"Use offset={end_line + 1} to continue reading (showing {offset}-{end_line} of {total_lines} lines)" - + return ReadResult( content=self._add_line_numbers(read_result.stdout, offset), total_lines=total_lines, file_size=file_size, truncated=truncated, - hint=hint + hint=hint, ) - + # Images larger than this are too expensive to inline as base64 in the # conversation context. Return metadata only and suggest vision_analyze. MAX_IMAGE_BYTES = 512 * 1024 # 512 KB @@ -535,7 +583,7 @@ class ShellFileOperations(FileOperations): file_size = int(stat_result.stdout.strip()) except ValueError: file_size = 0 - + if file_size > self.MAX_IMAGE_BYTES: return ReadResult( is_image=True, @@ -546,78 +594,75 @@ class ShellFileOperations(FileOperations): "Use vision_analyze to inspect the image, or reference it by path." ), ) - + # Get base64 content b64_cmd = f"base64 -w 0 {self._escape_shell_arg(path)} 2>/dev/null" b64_result = self._exec(b64_cmd, timeout=30) - + if b64_result.exit_code != 0: return ReadResult( - is_image=True, - is_binary=True, - file_size=file_size, - error=f"Failed to read image: {b64_result.stdout}" + is_image=True, is_binary=True, file_size=file_size, error=f"Failed to read image: {b64_result.stdout}" ) - + # Try to get dimensions (requires ImageMagick) dimensions = None - if self._has_command('identify'): + if self._has_command("identify"): dim_cmd = f"identify -format '%wx%h' {self._escape_shell_arg(path)} 2>/dev/null" dim_result = self._exec(dim_cmd) if dim_result.exit_code == 0: dimensions = dim_result.stdout.strip() - + # Determine MIME type from extension ext = os.path.splitext(path)[1].lower() mime_types = { - '.png': 'image/png', - '.jpg': 'image/jpeg', - '.jpeg': 'image/jpeg', - '.gif': 'image/gif', - '.webp': 'image/webp', - '.bmp': 'image/bmp', - '.ico': 'image/x-icon', + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".gif": "image/gif", + ".webp": "image/webp", + ".bmp": "image/bmp", + ".ico": "image/x-icon", } - mime_type = mime_types.get(ext, 'application/octet-stream') - + mime_type = mime_types.get(ext, "application/octet-stream") + return ReadResult( is_image=True, is_binary=True, file_size=file_size, base64_content=b64_result.stdout, mime_type=mime_type, - dimensions=dimensions + dimensions=dimensions, ) - + def _suggest_similar_files(self, path: str) -> ReadResult: """Suggest similar files when the requested file is not found.""" # Get directory and filename dir_path = os.path.dirname(path) or "." filename = os.path.basename(path) - + # List files in directory ls_cmd = f"ls -1 {self._escape_shell_arg(dir_path)} 2>/dev/null | head -20" ls_result = self._exec(ls_cmd) - + similar = [] if ls_result.exit_code == 0 and ls_result.stdout.strip(): - files = ls_result.stdout.strip().split('\n') + files = ls_result.stdout.strip().split("\n") # Simple similarity: files that share some characters with the target for f in files: # Check if filenames share significant overlap common = set(filename.lower()) & set(f.lower()) if len(common) >= len(filename) * 0.5: # 50% character overlap similar.append(os.path.join(dir_path, f)) - + return ReadResult( error=f"File not found: {path}", - similar_files=similar[:5] # Limit to 5 suggestions + similar_files=similar[:5], # Limit to 5 suggestions ) - + # ========================================================================= # WRITE Implementation # ========================================================================= - + def write_file(self, path: str, content: str) -> WriteResult: """ Write content to a file, creating parent directories as needed. @@ -643,41 +688,37 @@ class ShellFileOperations(FileOperations): # Create parent directories parent = os.path.dirname(path) dirs_created = False - + if parent: mkdir_cmd = f"mkdir -p {self._escape_shell_arg(parent)}" mkdir_result = self._exec(mkdir_cmd) if mkdir_result.exit_code == 0: dirs_created = True - + # Write via stdin pipe — content bypasses shell arg parsing entirely, # so there's no ARG_MAX limit regardless of file size. write_cmd = f"cat > {self._escape_shell_arg(path)}" write_result = self._exec(write_cmd, stdin_data=content) - + if write_result.exit_code != 0: return WriteResult(error=f"Failed to write file: {write_result.stdout}") - + # Get bytes written (wc -c is POSIX, works on Linux + macOS) stat_cmd = f"wc -c < {self._escape_shell_arg(path)} 2>/dev/null" stat_result = self._exec(stat_cmd) - + try: bytes_written = int(stat_result.stdout.strip()) except ValueError: - bytes_written = len(content.encode('utf-8')) - - return WriteResult( - bytes_written=bytes_written, - dirs_created=dirs_created - ) - + bytes_written = len(content.encode("utf-8")) + + return WriteResult(bytes_written=bytes_written, dirs_created=dirs_created) + # ========================================================================= # PATCH Implementation (Replace Mode) # ========================================================================= - - def patch_replace(self, path: str, old_string: str, new_string: str, - replace_all: bool = False) -> PatchResult: + + def patch_replace(self, path: str, old_string: str, new_string: str, replace_all: bool = False) -> PatchResult: """ Replace text in a file using fuzzy matching. @@ -700,47 +741,42 @@ class ShellFileOperations(FileOperations): # Read current content read_cmd = f"cat {self._escape_shell_arg(path)} 2>/dev/null" read_result = self._exec(read_cmd) - + if read_result.exit_code != 0: return PatchResult(error=f"Failed to read file: {path}") - + content = read_result.stdout - + # Import and use fuzzy matching from tools.fuzzy_match import fuzzy_find_and_replace - - new_content, match_count, error = fuzzy_find_and_replace( - content, old_string, new_string, replace_all - ) - + + new_content, match_count, error = fuzzy_find_and_replace(content, old_string, new_string, replace_all) + if error: return PatchResult(error=error) - + if match_count == 0: return PatchResult(error=f"Could not find match for old_string in {path}") - + # Write back write_result = self.write_file(path, new_content) if write_result.error: return PatchResult(error=f"Failed to write changes: {write_result.error}") - + # Generate diff diff = self._unified_diff(content, new_content, path) - + # Auto-lint lint_result = self._check_lint(path) - + return PatchResult( - success=True, - diff=diff, - files_modified=[path], - lint=lint_result.to_dict() if lint_result else None + success=True, diff=diff, files_modified=[path], lint=lint_result.to_dict() if lint_result else None ) - + def patch_v4a(self, patch_content: str) -> PatchResult: """ Apply a V4A format patch. - + V4A format: *** Begin Patch *** Update File: path/to/file.py @@ -749,66 +785,71 @@ class ShellFileOperations(FileOperations): -removed line +added line *** End Patch - + Args: patch_content: V4A format patch string - + Returns: PatchResult with changes made """ # Import patch parser - from tools.patch_parser import parse_v4a_patch, apply_v4a_operations - + from tools.patch_parser import apply_v4a_operations, parse_v4a_patch + operations, parse_error = parse_v4a_patch(patch_content) if parse_error: return PatchResult(error=f"Failed to parse patch: {parse_error}") - + # Apply operations result = apply_v4a_operations(operations, self) return result - + def _check_lint(self, path: str) -> LintResult: """ Run syntax check on a file after editing. - + Args: path: File path to lint - + Returns: LintResult with status and any errors """ ext = os.path.splitext(path)[1].lower() - + if ext not in LINTERS: return LintResult(skipped=True, message=f"No linter for {ext} files") - + # Check if linter command is available linter_cmd = LINTERS[ext] # Extract the base command (first word) base_cmd = linter_cmd.split()[0] - + if not self._has_command(base_cmd): return LintResult(skipped=True, message=f"{base_cmd} not available") - + # Run linter cmd = linter_cmd.format(file=self._escape_shell_arg(path)) result = self._exec(cmd, timeout=30) - - return LintResult( - success=result.exit_code == 0, - output=result.stdout.strip() if result.stdout.strip() else "" - ) - + + return LintResult(success=result.exit_code == 0, output=result.stdout.strip() if result.stdout.strip() else "") + # ========================================================================= # SEARCH Implementation # ========================================================================= - - def search(self, pattern: str, path: str = ".", target: str = "content", - file_glob: Optional[str] = None, limit: int = 50, offset: int = 0, - output_mode: str = "content", context: int = 0) -> SearchResult: + + def search( + self, + pattern: str, + path: str = ".", + target: str = "content", + file_glob: str | None = None, + limit: int = 50, + offset: int = 0, + output_mode: str = "content", + context: int = 0, + ) -> SearchResult: """ Search for content or files. - + Args: pattern: Regex (for content) or glob pattern (for files) path: Directory/file to search (default: cwd) @@ -818,280 +859,259 @@ class ShellFileOperations(FileOperations): offset: Skip first N results output_mode: "content", "files_only", or "count" context: Lines of context around matches - + Returns: SearchResult with matches or file list """ # Expand ~ and other shell paths path = self._expand_path(path) - + # Validate that the path exists before searching check = self._exec(f"test -e {self._escape_shell_arg(path)} && echo exists || echo not_found") if "not_found" in check.stdout: return SearchResult( - error=f"Path not found: {path}. Verify the path exists (use 'terminal' to check).", - total_count=0 + error=f"Path not found: {path}. Verify the path exists (use 'terminal' to check).", total_count=0 ) - + if target == "files": return self._search_files(pattern, path, limit, offset) else: - return self._search_content(pattern, path, file_glob, limit, offset, - output_mode, context) - + return self._search_content(pattern, path, file_glob, limit, offset, output_mode, context) + def _search_files(self, pattern: str, path: str, limit: int, offset: int) -> SearchResult: """Search for files by name pattern (glob-like).""" # Check if find is available (not on Windows without Git Bash/WSL) - if not self._has_command('find'): + if not self._has_command("find"): return SearchResult( - error="File search requires 'find' command. " - "On Windows, use Git Bash, WSL, or install Unix tools." + error="File search requires 'find' command. On Windows, use Git Bash, WSL, or install Unix tools." ) - + # Auto-prepend **/ for recursive search if not already present - if not pattern.startswith('**/') and '/' not in pattern: + if not pattern.startswith("**/") and "/" not in pattern: search_pattern = pattern else: - search_pattern = pattern.split('/')[-1] - + search_pattern = pattern.split("/")[-1] + # Use find with modification time sorting # -printf '%T@ %p\n' outputs: timestamp path # sort -rn sorts by timestamp descending (newest first) - cmd = f"find {self._escape_shell_arg(path)} -type f -name {self._escape_shell_arg(search_pattern)} " \ - f"-printf '%T@ %p\\n' 2>/dev/null | sort -rn | tail -n +{offset + 1} | head -n {limit}" - + cmd = ( + f"find {self._escape_shell_arg(path)} -type f -name {self._escape_shell_arg(search_pattern)} " + f"-printf '%T@ %p\\n' 2>/dev/null | sort -rn | tail -n +{offset + 1} | head -n {limit}" + ) + result = self._exec(cmd, timeout=60) - + if not result.stdout.strip(): # Try without -printf (BSD find compatibility -- macOS) - cmd_simple = f"find {self._escape_shell_arg(path)} -type f -name {self._escape_shell_arg(search_pattern)} " \ - f"2>/dev/null | head -n {limit + offset} | tail -n +{offset + 1}" + cmd_simple = ( + f"find {self._escape_shell_arg(path)} -type f -name {self._escape_shell_arg(search_pattern)} " + f"2>/dev/null | head -n {limit + offset} | tail -n +{offset + 1}" + ) result = self._exec(cmd_simple, timeout=60) - + files = [] - for line in result.stdout.strip().split('\n'): + for line in result.stdout.strip().split("\n"): if not line: continue # Parse "timestamp path" format - parts = line.split(' ', 1) - if len(parts) == 2 and parts[0].replace('.', '').isdigit(): + parts = line.split(" ", 1) + if len(parts) == 2 and parts[0].replace(".", "").isdigit(): files.append(parts[1]) else: files.append(line) - - return SearchResult( - files=files, - total_count=len(files) - ) - - def _search_content(self, pattern: str, path: str, file_glob: Optional[str], - limit: int, offset: int, output_mode: str, context: int) -> SearchResult: + + return SearchResult(files=files, total_count=len(files)) + + def _search_content( + self, pattern: str, path: str, file_glob: str | None, limit: int, offset: int, output_mode: str, context: int + ) -> SearchResult: """Search for content inside files (grep-like).""" # Try ripgrep first (fast), fallback to grep (slower but works) - if self._has_command('rg'): - return self._search_with_rg(pattern, path, file_glob, limit, offset, - output_mode, context) - elif self._has_command('grep'): - return self._search_with_grep(pattern, path, file_glob, limit, offset, - output_mode, context) + if self._has_command("rg"): + return self._search_with_rg(pattern, path, file_glob, limit, offset, output_mode, context) + elif self._has_command("grep"): + return self._search_with_grep(pattern, path, file_glob, limit, offset, output_mode, context) else: # Neither rg nor grep available (Windows without Git Bash, etc.) return SearchResult( error="Content search requires ripgrep (rg) or grep. " - "Install ripgrep: https://github.com/BurntSushi/ripgrep#installation" + "Install ripgrep: https://github.com/BurntSushi/ripgrep#installation" ) - - def _search_with_rg(self, pattern: str, path: str, file_glob: Optional[str], - limit: int, offset: int, output_mode: str, context: int) -> SearchResult: + + def _search_with_rg( + self, pattern: str, path: str, file_glob: str | None, limit: int, offset: int, output_mode: str, context: int + ) -> SearchResult: """Search using ripgrep.""" cmd_parts = ["rg", "--line-number", "--no-heading", "--with-filename"] - + # Add context if requested if context > 0: cmd_parts.extend(["-C", str(context)]) - + # Add file glob filter (must be quoted to prevent shell expansion) if file_glob: cmd_parts.extend(["--glob", self._escape_shell_arg(file_glob)]) - + # Output mode handling if output_mode == "files_only": cmd_parts.append("-l") # Files only elif output_mode == "count": cmd_parts.append("-c") # Count per file - + # Add pattern and path cmd_parts.append(self._escape_shell_arg(pattern)) cmd_parts.append(self._escape_shell_arg(path)) - + # Fetch extra rows so we can report the true total before slicing. # For context mode, rg emits separator lines ("--") between groups, # so we grab generously and filter in Python. fetch_limit = limit + offset + 200 if context > 0 else limit + offset cmd_parts.extend(["|", "head", "-n", str(fetch_limit)]) - + cmd = " ".join(cmd_parts) result = self._exec(cmd, timeout=60) - + # rg exit codes: 0=matches found, 1=no matches, 2=error if result.exit_code == 2 and not result.stdout.strip(): - error_msg = result.stderr.strip() if hasattr(result, 'stderr') and result.stderr else "Search error" + error_msg = result.stderr.strip() if hasattr(result, "stderr") and result.stderr else "Search error" return SearchResult(error=f"Search failed: {error_msg}", total_count=0) - + # Parse results based on output mode if output_mode == "files_only": - all_files = [f for f in result.stdout.strip().split('\n') if f] + all_files = [f for f in result.stdout.strip().split("\n") if f] total = len(all_files) - page = all_files[offset:offset + limit] + page = all_files[offset : offset + limit] return SearchResult(files=page, total_count=total) - + elif output_mode == "count": counts = {} - for line in result.stdout.strip().split('\n'): - if ':' in line: - parts = line.rsplit(':', 1) + for line in result.stdout.strip().split("\n"): + if ":" in line: + parts = line.rsplit(":", 1) if len(parts) == 2: try: counts[parts[0]] = int(parts[1]) except ValueError: pass return SearchResult(counts=counts, total_count=sum(counts.values())) - + else: # Parse content matches and context lines. # rg match lines: "file:lineno:content" (colon separator) # rg context lines: "file-lineno-content" (dash separator) # rg group seps: "--" matches = [] - for line in result.stdout.strip().split('\n'): + for line in result.stdout.strip().split("\n"): if not line or line == "--": continue - + # Try match line first (colon-separated: file:line:content) - parts = line.split(':', 2) + parts = line.split(":", 2) if len(parts) >= 3: try: - matches.append(SearchMatch( - path=parts[0], - line_number=int(parts[1]), - content=parts[2][:500] - )) + matches.append(SearchMatch(path=parts[0], line_number=int(parts[1]), content=parts[2][:500])) continue except ValueError: pass - + # Try context line (dash-separated: file-line-content) # Only attempt if context was requested to avoid false positives if context > 0: - parts = line.split('-', 2) + parts = line.split("-", 2) if len(parts) >= 3: try: - matches.append(SearchMatch( - path=parts[0], - line_number=int(parts[1]), - content=parts[2][:500] - )) + matches.append( + SearchMatch(path=parts[0], line_number=int(parts[1]), content=parts[2][:500]) + ) except ValueError: pass - + total = len(matches) - page = matches[offset:offset + limit] - return SearchResult( - matches=page, - total_count=total, - truncated=total > offset + limit - ) - - def _search_with_grep(self, pattern: str, path: str, file_glob: Optional[str], - limit: int, offset: int, output_mode: str, context: int) -> SearchResult: + page = matches[offset : offset + limit] + return SearchResult(matches=page, total_count=total, truncated=total > offset + limit) + + def _search_with_grep( + self, pattern: str, path: str, file_glob: str | None, limit: int, offset: int, output_mode: str, context: int + ) -> SearchResult: """Fallback search using grep.""" cmd_parts = ["grep", "-rnH"] # -H forces filename even for single-file searches - + # Add context if requested if context > 0: cmd_parts.extend(["-C", str(context)]) - + # Add file pattern filter (must be quoted to prevent shell expansion) if file_glob: cmd_parts.extend(["--include", self._escape_shell_arg(file_glob)]) - + # Output mode handling if output_mode == "files_only": cmd_parts.append("-l") elif output_mode == "count": cmd_parts.append("-c") - + # Add pattern and path cmd_parts.append(self._escape_shell_arg(pattern)) cmd_parts.append(self._escape_shell_arg(path)) - + # Fetch generously so we can compute total before slicing fetch_limit = limit + offset + (200 if context > 0 else 0) cmd_parts.extend(["|", "head", "-n", str(fetch_limit)]) - + cmd = " ".join(cmd_parts) result = self._exec(cmd, timeout=60) - + # grep exit codes: 0=matches found, 1=no matches, 2=error if result.exit_code == 2 and not result.stdout.strip(): - error_msg = result.stderr.strip() if hasattr(result, 'stderr') and result.stderr else "Search error" + error_msg = result.stderr.strip() if hasattr(result, "stderr") and result.stderr else "Search error" return SearchResult(error=f"Search failed: {error_msg}", total_count=0) - + if output_mode == "files_only": - all_files = [f for f in result.stdout.strip().split('\n') if f] + all_files = [f for f in result.stdout.strip().split("\n") if f] total = len(all_files) - page = all_files[offset:offset + limit] + page = all_files[offset : offset + limit] return SearchResult(files=page, total_count=total) - + elif output_mode == "count": counts = {} - for line in result.stdout.strip().split('\n'): - if ':' in line: - parts = line.rsplit(':', 1) + for line in result.stdout.strip().split("\n"): + if ":" in line: + parts = line.rsplit(":", 1) if len(parts) == 2: try: counts[parts[0]] = int(parts[1]) except ValueError: pass return SearchResult(counts=counts, total_count=sum(counts.values())) - + else: # grep match lines: "file:lineno:content" (colon) # grep context lines: "file-lineno-content" (dash) # grep group seps: "--" matches = [] - for line in result.stdout.strip().split('\n'): + for line in result.stdout.strip().split("\n"): if not line or line == "--": continue - - parts = line.split(':', 2) + + parts = line.split(":", 2) if len(parts) >= 3: try: - matches.append(SearchMatch( - path=parts[0], - line_number=int(parts[1]), - content=parts[2][:500] - )) + matches.append(SearchMatch(path=parts[0], line_number=int(parts[1]), content=parts[2][:500])) continue except ValueError: pass - + if context > 0: - parts = line.split('-', 2) + parts = line.split("-", 2) if len(parts) >= 3: try: - matches.append(SearchMatch( - path=parts[0], - line_number=int(parts[1]), - content=parts[2][:500] - )) + matches.append( + SearchMatch(path=parts[0], line_number=int(parts[1]), content=parts[2][:500]) + ) except ValueError: pass - + total = len(matches) - page = matches[offset:offset + limit] - return SearchResult( - matches=page, - total_count=total, - truncated=total > offset + limit - ) + page = matches[offset : offset + limit] + return SearchResult(matches=page, total_count=total, truncated=total > offset + limit) diff --git a/tools/file_tools.py b/tools/file_tools.py index 5ba098bd73..309742916b 100644 --- a/tools/file_tools.py +++ b/tools/file_tools.py @@ -3,11 +3,10 @@ import json import logging -import os import threading -from typing import Optional -from tools.file_operations import ShellFileOperations + from agent.redact import redact_sensitive_text +from tools.file_operations import ShellFileOperations logger = logging.getLogger(__name__) @@ -25,14 +24,19 @@ def _get_file_ops(task_id: str = "default") -> ShellFileOperations: Thread-safe: uses the same per-task creation locks as terminal_tool to prevent duplicate sandbox creation from concurrent tool calls. """ - from tools.terminal_tool import ( - _active_environments, _env_lock, _create_environment, - _get_env_config, _last_activity, _start_cleanup_thread, - _check_disk_usage_warning, - _creation_locks, _creation_locks_lock, - ) import time + from tools.terminal_tool import ( + _active_environments, + _create_environment, + _creation_locks, + _creation_locks_lock, + _env_lock, + _get_env_config, + _last_activity, + _start_cleanup_thread, + ) + # Fast path: check cache -- but also verify the underlying environment # is still alive (it may have been killed by the cleanup thread). with _file_ops_lock: @@ -143,17 +147,23 @@ def write_file_tool(path: str, content: str, task_id: str = "default") -> str: result = file_ops.write_file(path, content) return json.dumps(result.to_dict(), ensure_ascii=False) except Exception as e: - print(f"[FileTools] write_file error: {type(e).__name__}: {e}", flush=True) + print(f"[FileTools] write_file error: {type(e).__name__}: {e}", flush=True) return json.dumps({"error": str(e)}, ensure_ascii=False) -def patch_tool(mode: str = "replace", path: str = None, old_string: str = None, - new_string: str = None, replace_all: bool = False, patch: str = None, - task_id: str = "default") -> str: +def patch_tool( + mode: str = "replace", + path: str = None, + old_string: str = None, + new_string: str = None, + replace_all: bool = False, + patch: str = None, + task_id: str = "default", +) -> str: """Patch a file using replace mode or V4A patch format.""" try: file_ops = _get_file_ops(task_id) - + if mode == "replace": if not path: return json.dumps({"error": "path required"}) @@ -166,7 +176,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None, result = file_ops.patch_v4a(patch) else: return json.dumps({"error": f"Unknown mode: {mode}"}) - + result_dict = result.to_dict() result_json = json.dumps(result_dict, ensure_ascii=False) # Hint when old_string not found — saves iterations where the agent @@ -178,20 +188,33 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None, return json.dumps({"error": str(e)}, ensure_ascii=False) -def search_tool(pattern: str, target: str = "content", path: str = ".", - file_glob: str = None, limit: int = 50, offset: int = 0, - output_mode: str = "content", context: int = 0, - task_id: str = "default") -> str: +def search_tool( + pattern: str, + target: str = "content", + path: str = ".", + file_glob: str = None, + limit: int = 50, + offset: int = 0, + output_mode: str = "content", + context: int = 0, + task_id: str = "default", +) -> str: """Search for content or files.""" try: file_ops = _get_file_ops(task_id) result = file_ops.search( - pattern=pattern, path=path, target=target, file_glob=file_glob, - limit=limit, offset=offset, output_mode=output_mode, context=context + pattern=pattern, + path=path, + target=target, + file_glob=file_glob, + limit=limit, + offset=offset, + output_mode=output_mode, + context=context, ) - if hasattr(result, 'matches'): + if hasattr(result, "matches"): for m in result.matches: - if hasattr(m, 'content') and m.content: + if hasattr(m, "content") and m.content: m.content = redact_sensitive_text(m.content) result_dict = result.to_dict() result_json = json.dumps(result_dict, ensure_ascii=False) @@ -209,7 +232,7 @@ FILE_TOOLS = [ {"name": "read_file", "function": read_file_tool}, {"name": "write_file", "function": write_file_tool}, {"name": "patch", "function": patch_tool}, - {"name": "search_files", "function": search_tool} + {"name": "search_files", "function": search_tool}, ] @@ -227,8 +250,10 @@ from tools.registry import registry def _check_file_reqs(): """Lazy wrapper to avoid circular import with tools/__init__.py.""" from tools import check_file_requirements + return check_file_requirements() + READ_FILE_SCHEMA = { "name": "read_file", "description": "Read a text file with line numbers and pagination. Use this instead of cat/head/tail in terminal. Output format: 'LINE_NUM|CONTENT'. Suggests similar filenames if not found. Use offset and limit for large files. NOTE: Cannot read images or binary files — use vision_analyze for images.", @@ -236,11 +261,21 @@ READ_FILE_SCHEMA = { "type": "object", "properties": { "path": {"type": "string", "description": "Path to the file to read (absolute, relative, or ~/path)"}, - "offset": {"type": "integer", "description": "Line number to start reading from (1-indexed, default: 1)", "default": 1, "minimum": 1}, - "limit": {"type": "integer", "description": "Maximum number of lines to read (default: 500, max: 2000)", "default": 500, "maximum": 2000} + "offset": { + "type": "integer", + "description": "Line number to start reading from (1-indexed, default: 1)", + "default": 1, + "minimum": 1, + }, + "limit": { + "type": "integer", + "description": "Maximum number of lines to read (default: 500, max: 2000)", + "default": 500, + "maximum": 2000, + }, }, - "required": ["path"] - } + "required": ["path"], + }, } WRITE_FILE_SCHEMA = { @@ -249,11 +284,14 @@ WRITE_FILE_SCHEMA = { "parameters": { "type": "object", "properties": { - "path": {"type": "string", "description": "Path to the file to write (will be created if it doesn't exist, overwritten if it does)"}, - "content": {"type": "string", "description": "Complete content to write to the file"} + "path": { + "type": "string", + "description": "Path to the file to write (will be created if it doesn't exist, overwritten if it does)", + }, + "content": {"type": "string", "description": "Complete content to write to the file"}, }, - "required": ["path", "content"] - } + "required": ["path", "content"], + }, } PATCH_SCHEMA = { @@ -262,15 +300,33 @@ PATCH_SCHEMA = { "parameters": { "type": "object", "properties": { - "mode": {"type": "string", "enum": ["replace", "patch"], "description": "Edit mode: 'replace' for targeted find-and-replace, 'patch' for V4A multi-file patches", "default": "replace"}, + "mode": { + "type": "string", + "enum": ["replace", "patch"], + "description": "Edit mode: 'replace' for targeted find-and-replace, 'patch' for V4A multi-file patches", + "default": "replace", + }, "path": {"type": "string", "description": "File path to edit (required for 'replace' mode)"}, - "old_string": {"type": "string", "description": "Text to find in the file (required for 'replace' mode). Must be unique in the file unless replace_all=true. Include enough surrounding context to ensure uniqueness."}, - "new_string": {"type": "string", "description": "Replacement text (required for 'replace' mode). Can be empty string to delete the matched text."}, - "replace_all": {"type": "boolean", "description": "Replace all occurrences instead of requiring a unique match (default: false)", "default": False}, - "patch": {"type": "string", "description": "V4A format patch content (required for 'patch' mode). Format:\n*** Begin Patch\n*** Update File: path/to/file\n@@ context hint @@\n context line\n-removed line\n+added line\n*** End Patch"} + "old_string": { + "type": "string", + "description": "Text to find in the file (required for 'replace' mode). Must be unique in the file unless replace_all=true. Include enough surrounding context to ensure uniqueness.", + }, + "new_string": { + "type": "string", + "description": "Replacement text (required for 'replace' mode). Can be empty string to delete the matched text.", + }, + "replace_all": { + "type": "boolean", + "description": "Replace all occurrences instead of requiring a unique match (default: false)", + "default": False, + }, + "patch": { + "type": "string", + "description": "V4A format patch content (required for 'patch' mode). Format:\n*** Begin Patch\n*** Update File: path/to/file\n@@ context hint @@\n context line\n-removed line\n+added line\n*** End Patch", + }, }, - "required": ["mode"] - } + "required": ["mode"], + }, } SEARCH_FILES_SCHEMA = { @@ -279,23 +335,57 @@ SEARCH_FILES_SCHEMA = { "parameters": { "type": "object", "properties": { - "pattern": {"type": "string", "description": "Regex pattern for content search, or glob pattern (e.g., '*.py') for file search"}, - "target": {"type": "string", "enum": ["content", "files"], "description": "'content' searches inside file contents, 'files' searches for files by name", "default": "content"}, - "path": {"type": "string", "description": "Directory or file to search in (default: current working directory)", "default": "."}, - "file_glob": {"type": "string", "description": "Filter files by pattern in grep mode (e.g., '*.py' to only search Python files)"}, - "limit": {"type": "integer", "description": "Maximum number of results to return (default: 50)", "default": 50}, - "offset": {"type": "integer", "description": "Skip first N results for pagination (default: 0)", "default": 0}, - "output_mode": {"type": "string", "enum": ["content", "files_only", "count"], "description": "Output format for grep mode: 'content' shows matching lines with line numbers, 'files_only' lists file paths, 'count' shows match counts per file", "default": "content"}, - "context": {"type": "integer", "description": "Number of context lines before and after each match (grep mode only)", "default": 0} + "pattern": { + "type": "string", + "description": "Regex pattern for content search, or glob pattern (e.g., '*.py') for file search", + }, + "target": { + "type": "string", + "enum": ["content", "files"], + "description": "'content' searches inside file contents, 'files' searches for files by name", + "default": "content", + }, + "path": { + "type": "string", + "description": "Directory or file to search in (default: current working directory)", + "default": ".", + }, + "file_glob": { + "type": "string", + "description": "Filter files by pattern in grep mode (e.g., '*.py' to only search Python files)", + }, + "limit": { + "type": "integer", + "description": "Maximum number of results to return (default: 50)", + "default": 50, + }, + "offset": { + "type": "integer", + "description": "Skip first N results for pagination (default: 0)", + "default": 0, + }, + "output_mode": { + "type": "string", + "enum": ["content", "files_only", "count"], + "description": "Output format for grep mode: 'content' shows matching lines with line numbers, 'files_only' lists file paths, 'count' shows match counts per file", + "default": "content", + }, + "context": { + "type": "integer", + "description": "Number of context lines before and after each match (grep mode only)", + "default": 0, + }, }, - "required": ["pattern"] - } + "required": ["pattern"], + }, } def _handle_read_file(args, **kw): tid = kw.get("task_id") or "default" - return read_file_tool(path=args.get("path", ""), offset=args.get("offset", 1), limit=args.get("limit", 500), task_id=tid) + return read_file_tool( + path=args.get("path", ""), offset=args.get("offset", 1), limit=args.get("limit", 500), task_id=tid + ) def _handle_write_file(args, **kw): @@ -306,9 +396,14 @@ def _handle_write_file(args, **kw): def _handle_patch(args, **kw): tid = kw.get("task_id") or "default" return patch_tool( - mode=args.get("mode", "replace"), path=args.get("path"), - old_string=args.get("old_string"), new_string=args.get("new_string"), - replace_all=args.get("replace_all", False), patch=args.get("patch"), task_id=tid) + mode=args.get("mode", "replace"), + path=args.get("path"), + old_string=args.get("old_string"), + new_string=args.get("new_string"), + replace_all=args.get("replace_all", False), + patch=args.get("patch"), + task_id=tid, + ) def _handle_search_files(args, **kw): @@ -317,12 +412,29 @@ def _handle_search_files(args, **kw): raw_target = args.get("target", "content") target = target_map.get(raw_target, raw_target) return search_tool( - pattern=args.get("pattern", ""), target=target, path=args.get("path", "."), - file_glob=args.get("file_glob"), limit=args.get("limit", 50), offset=args.get("offset", 0), - output_mode=args.get("output_mode", "content"), context=args.get("context", 0), task_id=tid) + pattern=args.get("pattern", ""), + target=target, + path=args.get("path", "."), + file_glob=args.get("file_glob"), + limit=args.get("limit", 50), + offset=args.get("offset", 0), + output_mode=args.get("output_mode", "content"), + context=args.get("context", 0), + task_id=tid, + ) -registry.register(name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs) -registry.register(name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs) +registry.register( + name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs +) +registry.register( + name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs +) registry.register(name="patch", toolset="file", schema=PATCH_SCHEMA, handler=_handle_patch, check_fn=_check_file_reqs) -registry.register(name="search_files", toolset="file", schema=SEARCH_FILES_SCHEMA, handler=_handle_search_files, check_fn=_check_file_reqs) +registry.register( + name="search_files", + toolset="file", + schema=SEARCH_FILES_SCHEMA, + handler=_handle_search_files, + check_fn=_check_file_reqs, +) diff --git a/tools/fuzzy_match.py b/tools/fuzzy_match.py index bc8e344036..64dd0b8cb6 100644 --- a/tools/fuzzy_match.py +++ b/tools/fuzzy_match.py @@ -19,7 +19,7 @@ The 9-strategy chain (inspired by OpenCode): Usage: from tools.fuzzy_match import fuzzy_find_and_replace - + new_content, match_count, error = fuzzy_find_and_replace( content="def foo():\\n pass", old_string="def foo():", @@ -29,21 +29,22 @@ Usage: """ import re -from typing import Tuple, Optional, List, Callable +from collections.abc import Callable from difflib import SequenceMatcher -def fuzzy_find_and_replace(content: str, old_string: str, new_string: str, - replace_all: bool = False) -> Tuple[str, int, Optional[str]]: +def fuzzy_find_and_replace( + content: str, old_string: str, new_string: str, replace_all: bool = False +) -> tuple[str, int, str | None]: """ Find and replace text using a chain of increasingly fuzzy matching strategies. - + Args: content: The file content to search in old_string: The text to find new_string: The replacement text replace_all: If True, replace all occurrences; if False, require uniqueness - + Returns: Tuple of (new_content, match_count, error_message) - If successful: (modified_content, number_of_replacements, None) @@ -51,12 +52,12 @@ def fuzzy_find_and_replace(content: str, old_string: str, new_string: str, """ if not old_string: return content, 0, "old_string cannot be empty" - + if old_string == new_string: return content, 0, "old_string and new_string are identical" - + # Try each matching strategy in order - strategies: List[Tuple[str, Callable]] = [ + strategies: list[tuple[str, Callable]] = [ ("exact", _strategy_exact), ("line_trimmed", _strategy_line_trimmed), ("whitespace_normalized", _strategy_whitespace_normalized), @@ -66,46 +67,50 @@ def fuzzy_find_and_replace(content: str, old_string: str, new_string: str, ("block_anchor", _strategy_block_anchor), ("context_aware", _strategy_context_aware), ] - + for strategy_name, strategy_fn in strategies: matches = strategy_fn(content, old_string) - + if matches: # Found matches with this strategy if len(matches) > 1 and not replace_all: - return content, 0, ( - f"Found {len(matches)} matches for old_string. " - f"Provide more context to make it unique, or use replace_all=True." + return ( + content, + 0, + ( + f"Found {len(matches)} matches for old_string. " + f"Provide more context to make it unique, or use replace_all=True." + ), ) - + # Perform replacement new_content = _apply_replacements(content, matches, new_string) return new_content, len(matches), None - + # No strategy found a match return content, 0, "Could not find a match for old_string in the file" -def _apply_replacements(content: str, matches: List[Tuple[int, int]], new_string: str) -> str: +def _apply_replacements(content: str, matches: list[tuple[int, int]], new_string: str) -> str: """ Apply replacements at the given positions. - + Args: content: Original content matches: List of (start, end) positions to replace new_string: Replacement text - + Returns: Content with replacements applied """ # Sort matches by position (descending) to replace from end to start # This preserves positions of earlier matches sorted_matches = sorted(matches, key=lambda x: x[0], reverse=True) - + result = content for start, end in sorted_matches: result = result[:start] + new_string + result[end:] - + return result @@ -113,7 +118,8 @@ def _apply_replacements(content: str, matches: List[Tuple[int, int]], new_string # Matching Strategies # ============================================================================= -def _strategy_exact(content: str, pattern: str) -> List[Tuple[int, int]]: + +def _strategy_exact(content: str, pattern: str) -> list[tuple[int, int]]: """Strategy 1: Exact string match.""" matches = [] start = 0 @@ -126,206 +132,201 @@ def _strategy_exact(content: str, pattern: str) -> List[Tuple[int, int]]: return matches -def _strategy_line_trimmed(content: str, pattern: str) -> List[Tuple[int, int]]: +def _strategy_line_trimmed(content: str, pattern: str) -> list[tuple[int, int]]: """ Strategy 2: Match with line-by-line whitespace trimming. - + Strips leading/trailing whitespace from each line before matching. """ # Normalize pattern and content by trimming each line - pattern_lines = [line.strip() for line in pattern.split('\n')] - pattern_normalized = '\n'.join(pattern_lines) - - content_lines = content.split('\n') + pattern_lines = [line.strip() for line in pattern.split("\n")] + pattern_normalized = "\n".join(pattern_lines) + + content_lines = content.split("\n") content_normalized_lines = [line.strip() for line in content_lines] - + # Build mapping from normalized positions back to original positions - return _find_normalized_matches( - content, content_lines, content_normalized_lines, - pattern, pattern_normalized - ) + return _find_normalized_matches(content, content_lines, content_normalized_lines, pattern, pattern_normalized) -def _strategy_whitespace_normalized(content: str, pattern: str) -> List[Tuple[int, int]]: +def _strategy_whitespace_normalized(content: str, pattern: str) -> list[tuple[int, int]]: """ Strategy 3: Collapse multiple whitespace to single space. """ + def normalize(s): # Collapse multiple spaces/tabs to single space, preserve newlines - return re.sub(r'[ \t]+', ' ', s) - + return re.sub(r"[ \t]+", " ", s) + pattern_normalized = normalize(pattern) content_normalized = normalize(content) - + # Find in normalized, map back to original matches_in_normalized = _strategy_exact(content_normalized, pattern_normalized) - + if not matches_in_normalized: return [] - + # Map positions back to original content return _map_normalized_positions(content, content_normalized, matches_in_normalized) -def _strategy_indentation_flexible(content: str, pattern: str) -> List[Tuple[int, int]]: +def _strategy_indentation_flexible(content: str, pattern: str) -> list[tuple[int, int]]: """ Strategy 4: Ignore indentation differences entirely. - + Strips all leading whitespace from lines before matching. """ + def strip_indent(s): - return '\n'.join(line.lstrip() for line in s.split('\n')) - + return "\n".join(line.lstrip() for line in s.split("\n")) + pattern_stripped = strip_indent(pattern) - - content_lines = content.split('\n') + + content_lines = content.split("\n") content_stripped_lines = [line.lstrip() for line in content_lines] - pattern_lines = [line.lstrip() for line in pattern.split('\n')] - - return _find_normalized_matches( - content, content_lines, content_stripped_lines, - pattern, '\n'.join(pattern_lines) - ) + pattern_lines = [line.lstrip() for line in pattern.split("\n")] + + return _find_normalized_matches(content, content_lines, content_stripped_lines, pattern, "\n".join(pattern_lines)) -def _strategy_escape_normalized(content: str, pattern: str) -> List[Tuple[int, int]]: +def _strategy_escape_normalized(content: str, pattern: str) -> list[tuple[int, int]]: """ Strategy 5: Convert escape sequences to actual characters. - + Handles \\n -> newline, \\t -> tab, etc. """ + def unescape(s): # Convert common escape sequences - return s.replace('\\n', '\n').replace('\\t', '\t').replace('\\r', '\r') - + return s.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r") + pattern_unescaped = unescape(pattern) - + if pattern_unescaped == pattern: # No escapes to convert, skip this strategy return [] - + return _strategy_exact(content, pattern_unescaped) -def _strategy_trimmed_boundary(content: str, pattern: str) -> List[Tuple[int, int]]: +def _strategy_trimmed_boundary(content: str, pattern: str) -> list[tuple[int, int]]: """ Strategy 6: Trim whitespace from first and last lines only. - + Useful when the pattern boundaries have whitespace differences. """ - pattern_lines = pattern.split('\n') + pattern_lines = pattern.split("\n") if not pattern_lines: return [] - + # Trim only first and last lines pattern_lines[0] = pattern_lines[0].strip() if len(pattern_lines) > 1: pattern_lines[-1] = pattern_lines[-1].strip() - - modified_pattern = '\n'.join(pattern_lines) - - content_lines = content.split('\n') - + + modified_pattern = "\n".join(pattern_lines) + + content_lines = content.split("\n") + # Search through content for matching block matches = [] pattern_line_count = len(pattern_lines) - + for i in range(len(content_lines) - pattern_line_count + 1): - block_lines = content_lines[i:i + pattern_line_count] - + block_lines = content_lines[i : i + pattern_line_count] + # Trim first and last of this block check_lines = block_lines.copy() check_lines[0] = check_lines[0].strip() if len(check_lines) > 1: check_lines[-1] = check_lines[-1].strip() - - if '\n'.join(check_lines) == modified_pattern: + + if "\n".join(check_lines) == modified_pattern: # Found match - calculate original positions start_pos = sum(len(line) + 1 for line in content_lines[:i]) - end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1 + end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1 if end_pos >= len(content): end_pos = len(content) matches.append((start_pos, end_pos)) - + return matches -def _strategy_block_anchor(content: str, pattern: str) -> List[Tuple[int, int]]: +def _strategy_block_anchor(content: str, pattern: str) -> list[tuple[int, int]]: """ Strategy 7: Match by anchoring on first and last lines. - + If first and last lines match exactly, accept middle with 70% similarity. """ - pattern_lines = pattern.split('\n') + pattern_lines = pattern.split("\n") if len(pattern_lines) < 2: return [] # Need at least 2 lines for anchoring - + first_line = pattern_lines[0].strip() last_line = pattern_lines[-1].strip() - - content_lines = content.split('\n') + + content_lines = content.split("\n") matches = [] - + pattern_line_count = len(pattern_lines) - + for i in range(len(content_lines) - pattern_line_count + 1): # Check if first and last lines match - if (content_lines[i].strip() == first_line and - content_lines[i + pattern_line_count - 1].strip() == last_line): - + if content_lines[i].strip() == first_line and content_lines[i + pattern_line_count - 1].strip() == last_line: # Check middle similarity if pattern_line_count <= 2: # Only first and last, they match similarity = 1.0 else: - content_middle = '\n'.join(content_lines[i+1:i+pattern_line_count-1]) - pattern_middle = '\n'.join(pattern_lines[1:-1]) + content_middle = "\n".join(content_lines[i + 1 : i + pattern_line_count - 1]) + pattern_middle = "\n".join(pattern_lines[1:-1]) similarity = SequenceMatcher(None, content_middle, pattern_middle).ratio() - + if similarity >= 0.70: # Calculate positions start_pos = sum(len(line) + 1 for line in content_lines[:i]) - end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1 + end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1 if end_pos >= len(content): end_pos = len(content) matches.append((start_pos, end_pos)) - + return matches -def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]: +def _strategy_context_aware(content: str, pattern: str) -> list[tuple[int, int]]: """ Strategy 8: Line-by-line similarity with 50% threshold. - + Finds blocks where at least 50% of lines have high similarity. """ - pattern_lines = pattern.split('\n') - content_lines = content.split('\n') - + pattern_lines = pattern.split("\n") + content_lines = content.split("\n") + if not pattern_lines: return [] - + matches = [] pattern_line_count = len(pattern_lines) - + for i in range(len(content_lines) - pattern_line_count + 1): - block_lines = content_lines[i:i + pattern_line_count] - + block_lines = content_lines[i : i + pattern_line_count] + # Calculate line-by-line similarity high_similarity_count = 0 for p_line, c_line in zip(pattern_lines, block_lines): sim = SequenceMatcher(None, p_line.strip(), c_line.strip()).ratio() if sim >= 0.80: high_similarity_count += 1 - + # Need at least 50% of lines to have high similarity if high_similarity_count >= len(pattern_lines) * 0.5: start_pos = sum(len(line) + 1 for line in content_lines[:i]) - end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1 + end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1 if end_pos >= len(content): end_pos = len(content) matches.append((start_pos, end_pos)) - + return matches @@ -333,74 +334,76 @@ def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]] # Helper Functions # ============================================================================= -def _find_normalized_matches(content: str, content_lines: List[str], - content_normalized_lines: List[str], - pattern: str, pattern_normalized: str) -> List[Tuple[int, int]]: + +def _find_normalized_matches( + content: str, content_lines: list[str], content_normalized_lines: list[str], pattern: str, pattern_normalized: str +) -> list[tuple[int, int]]: """ Find matches in normalized content and map back to original positions. - + Args: content: Original content string content_lines: Original content split by lines content_normalized_lines: Normalized content lines pattern: Original pattern pattern_normalized: Normalized pattern - + Returns: List of (start, end) positions in the original content """ - pattern_norm_lines = pattern_normalized.split('\n') + pattern_norm_lines = pattern_normalized.split("\n") num_pattern_lines = len(pattern_norm_lines) - + matches = [] - + for i in range(len(content_normalized_lines) - num_pattern_lines + 1): # Check if this block matches - block = '\n'.join(content_normalized_lines[i:i + num_pattern_lines]) - + block = "\n".join(content_normalized_lines[i : i + num_pattern_lines]) + if block == pattern_normalized: # Found a match - calculate original positions start_pos = sum(len(line) + 1 for line in content_lines[:i]) - end_pos = sum(len(line) + 1 for line in content_lines[:i + num_pattern_lines]) - 1 - + end_pos = sum(len(line) + 1 for line in content_lines[: i + num_pattern_lines]) - 1 + # Handle case where end is past content if end_pos >= len(content): end_pos = len(content) - + matches.append((start_pos, end_pos)) - + return matches -def _map_normalized_positions(original: str, normalized: str, - normalized_matches: List[Tuple[int, int]]) -> List[Tuple[int, int]]: +def _map_normalized_positions( + original: str, normalized: str, normalized_matches: list[tuple[int, int]] +) -> list[tuple[int, int]]: """ Map positions from normalized string back to original. - + This is a best-effort mapping that works for whitespace normalization. """ if not normalized_matches: return [] - + # Build character mapping from normalized to original orig_to_norm = [] # orig_to_norm[i] = position in normalized - + orig_idx = 0 norm_idx = 0 - + while orig_idx < len(original) and norm_idx < len(normalized): if original[orig_idx] == normalized[norm_idx]: orig_to_norm.append(norm_idx) orig_idx += 1 norm_idx += 1 - elif original[orig_idx] in ' \t' and normalized[norm_idx] == ' ': + elif original[orig_idx] in " \t" and normalized[norm_idx] == " ": # Original has space/tab, normalized collapsed to space orig_to_norm.append(norm_idx) orig_idx += 1 # Don't advance norm_idx yet - wait until all whitespace consumed - if orig_idx < len(original) and original[orig_idx] not in ' \t': + if orig_idx < len(original) and original[orig_idx] not in " \t": norm_idx += 1 - elif original[orig_idx] in ' \t': + elif original[orig_idx] in " \t": # Extra whitespace in original orig_to_norm.append(norm_idx) orig_idx += 1 @@ -408,21 +411,21 @@ def _map_normalized_positions(original: str, normalized: str, # Mismatch - shouldn't happen with our normalization orig_to_norm.append(norm_idx) orig_idx += 1 - + # Fill remaining while orig_idx < len(original): orig_to_norm.append(len(normalized)) orig_idx += 1 - + # Reverse mapping: for each normalized position, find original range norm_to_orig_start = {} norm_to_orig_end = {} - + for orig_pos, norm_pos in enumerate(orig_to_norm): if norm_pos not in norm_to_orig_start: norm_to_orig_start[norm_pos] = orig_pos norm_to_orig_end[norm_pos] = orig_pos - + # Map matches original_matches = [] for norm_start, norm_end in normalized_matches: @@ -432,17 +435,17 @@ def _map_normalized_positions(original: str, normalized: str, else: # Find nearest orig_start = min(i for i, n in enumerate(orig_to_norm) if n >= norm_start) - + # Find original end if norm_end - 1 in norm_to_orig_end: orig_end = norm_to_orig_end[norm_end - 1] + 1 else: orig_end = orig_start + (norm_end - norm_start) - + # Expand to include trailing whitespace that was normalized - while orig_end < len(original) and original[orig_end] in ' \t': + while orig_end < len(original) and original[orig_end] in " \t": orig_end += 1 - + original_matches.append((orig_start, min(orig_end, len(original)))) - + return original_matches diff --git a/tools/homeassistant_tool.py b/tools/homeassistant_tool.py index a9077cff35..6719638ac8 100644 --- a/tools/homeassistant_tool.py +++ b/tools/homeassistant_tool.py @@ -15,7 +15,7 @@ import json import logging import os import re -from typing import Any, Dict, Optional +from typing import Any logger = logging.getLogger(__name__) @@ -35,23 +35,26 @@ def _get_config(): _HASS_TOKEN or os.getenv("HASS_TOKEN", ""), ) + # Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1") _ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+$") # Service domains blocked for security -- these allow arbitrary code/command # execution on the HA host or enable SSRF attacks on the local network. # HA provides zero service-level access control; all safety must be in our layer. -_BLOCKED_DOMAINS = frozenset({ - "shell_command", # arbitrary shell commands as root in HA container - "command_line", # sensors/switches that execute shell commands - "python_script", # sandboxed but can escalate via hass.services.call() - "pyscript", # scripting integration with broader access - "hassio", # addon control, host shutdown/reboot, stdin to containers - "rest_command", # HTTP requests from HA server (SSRF vector) -}) +_BLOCKED_DOMAINS = frozenset( + { + "shell_command", # arbitrary shell commands as root in HA container + "command_line", # sensors/switches that execute shell commands + "python_script", # sandboxed but can escalate via hass.services.call() + "pyscript", # scripting integration with broader access + "hassio", # addon control, host shutdown/reboot, stdin to containers + "rest_command", # HTTP requests from HA server (SSRF vector) + } +) -def _get_headers(token: str = "") -> Dict[str, str]: +def _get_headers(token: str = "") -> dict[str, str]: """Return authorization headers for HA REST API.""" if not token: _, token = _get_config() @@ -65,11 +68,12 @@ def _get_headers(token: str = "") -> Dict[str, str]: # Async helpers (called from sync handlers via run_until_complete) # --------------------------------------------------------------------------- + def _filter_and_summarize( states: list, - domain: Optional[str] = None, - area: Optional[str] = None, -) -> Dict[str, Any]: + domain: str | None = None, + area: str | None = None, +) -> dict[str, Any]: """Filter raw HA states by domain/area and return a compact summary.""" if domain: states = [s for s in states if s.get("entity_id", "").startswith(f"{domain}.")] @@ -77,26 +81,29 @@ def _filter_and_summarize( if area: area_lower = area.lower() states = [ - s for s in states + s + for s in states if area_lower in (s.get("attributes", {}).get("friendly_name", "") or "").lower() or area_lower in (s.get("attributes", {}).get("area", "") or "").lower() ] entities = [] for s in states: - entities.append({ - "entity_id": s["entity_id"], - "state": s["state"], - "friendly_name": s.get("attributes", {}).get("friendly_name", ""), - }) + entities.append( + { + "entity_id": s["entity_id"], + "state": s["state"], + "friendly_name": s.get("attributes", {}).get("friendly_name", ""), + } + ) return {"count": len(entities), "entities": entities} async def _async_list_entities( - domain: Optional[str] = None, - area: Optional[str] = None, -) -> Dict[str, Any]: + domain: str | None = None, + area: str | None = None, +) -> dict[str, Any]: """Fetch entity states from HA and optionally filter by domain/area.""" import aiohttp @@ -110,7 +117,7 @@ async def _async_list_entities( return _filter_and_summarize(states, domain, area) -async def _async_get_state(entity_id: str) -> Dict[str, Any]: +async def _async_get_state(entity_id: str) -> dict[str, Any]: """Fetch detailed state of a single entity.""" import aiohttp @@ -131,11 +138,11 @@ async def _async_get_state(entity_id: str) -> Dict[str, Any]: def _build_service_payload( - entity_id: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: + entity_id: str | None = None, + data: dict[str, Any] | None = None, +) -> dict[str, Any]: """Build the JSON payload for a HA service call.""" - payload: Dict[str, Any] = {} + payload: dict[str, Any] = {} if data: payload.update(data) # entity_id parameter takes precedence over data["entity_id"] @@ -148,15 +155,17 @@ def _parse_service_response( domain: str, service: str, result: Any, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Parse HA service call response into a structured result.""" affected = [] if isinstance(result, list): for s in result: - affected.append({ - "entity_id": s.get("entity_id", ""), - "state": s.get("state", ""), - }) + affected.append( + { + "entity_id": s.get("entity_id", ""), + "state": s.get("state", ""), + } + ) return { "success": True, @@ -168,9 +177,9 @@ def _parse_service_response( async def _async_call_service( domain: str, service: str, - entity_id: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: + entity_id: str | None = None, + data: dict[str, Any] | None = None, +) -> dict[str, Any]: """Call a Home Assistant service.""" import aiohttp @@ -178,15 +187,17 @@ async def _async_call_service( url = f"{hass_url}/api/services/{domain}/{service}" payload = _build_service_payload(entity_id, data) - async with aiohttp.ClientSession() as session: - async with session.post( + async with ( + aiohttp.ClientSession() as session, + session.post( url, headers=_get_headers(hass_token), json=payload, timeout=aiohttp.ClientTimeout(total=15), - ) as resp: - resp.raise_for_status() - result = await resp.json() + ) as resp, + ): + resp.raise_for_status() + result = await resp.json() return _parse_service_response(domain, service, result) @@ -195,6 +206,7 @@ async def _async_call_service( # Sync wrappers (handler signature: (args, **kw) -> str) # --------------------------------------------------------------------------- + def _run_async(coro): """Run an async coroutine from a sync handler.""" try: @@ -205,6 +217,7 @@ def _run_async(coro): if loop and loop.is_running(): # Already inside an event loop -- create a new thread import concurrent.futures + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: future = pool.submit(asyncio.run, coro) return future.result(timeout=30) @@ -247,10 +260,12 @@ def _handle_call_service(args: dict, **kw) -> str: return json.dumps({"error": "Missing required parameters: domain and service"}) if domain in _BLOCKED_DOMAINS: - return json.dumps({ - "error": f"Service domain '{domain}' is blocked for security. " - f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}" - }) + return json.dumps( + { + "error": f"Service domain '{domain}' is blocked for security. " + f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}" + } + ) entity_id = args.get("entity_id") if entity_id and not _ENTITY_ID_RE.match(entity_id): @@ -269,7 +284,8 @@ def _handle_call_service(args: dict, **kw) -> str: # List services # --------------------------------------------------------------------------- -async def _async_list_services(domain: Optional[str] = None) -> Dict[str, Any]: + +async def _async_list_services(domain: str | None = None) -> dict[str, Any]: """Fetch available services from HA and optionally filter by domain.""" import aiohttp @@ -290,13 +306,10 @@ async def _async_list_services(domain: Optional[str] = None) -> Dict[str, Any]: d = svc_domain.get("domain", "") domain_services = {} for svc_name, svc_info in svc_domain.get("services", {}).items(): - svc_entry: Dict[str, Any] = {"description": svc_info.get("description", "")} + svc_entry: dict[str, Any] = {"description": svc_info.get("description", "")} fields = svc_info.get("fields", {}) if fields: - svc_entry["fields"] = { - k: v.get("description", "") for k, v in fields.items() - if isinstance(v, dict) - } + svc_entry["fields"] = {k: v.get("description", "") for k, v in fields.items() if isinstance(v, dict)} domain_services[svc_name] = svc_entry result.append({"domain": d, "services": domain_services}) @@ -318,6 +331,7 @@ def _handle_list_services(args: dict, **kw) -> str: # Availability check # --------------------------------------------------------------------------- + def _check_ha_available() -> bool: """Tool is only available when HASS_TOKEN is set.""" return bool(os.getenv("HASS_TOKEN")) @@ -369,8 +383,7 @@ HA_GET_STATE_SCHEMA = { "entity_id": { "type": "string", "description": ( - "The entity ID to query (e.g. 'light.living_room', " - "'climate.thermostat', 'sensor.temperature')." + "The entity ID to query (e.g. 'light.living_room', 'climate.thermostat', 'sensor.temperature')." ), }, }, @@ -392,8 +405,7 @@ HA_LIST_SERVICES_SCHEMA = { "domain": { "type": "string", "description": ( - "Filter by domain (e.g. 'light', 'climate', 'switch'). " - "Omit to list services for all domains." + "Filter by domain (e.g. 'light', 'climate', 'switch'). Omit to list services for all domains." ), }, }, @@ -428,8 +440,7 @@ HA_CALL_SERVICE_SCHEMA = { "entity_id": { "type": "string", "description": ( - "Target entity ID (e.g. 'light.living_room'). " - "Some services (like scene.turn_on) may not need this." + "Target entity ID (e.g. 'light.living_room'). Some services (like scene.turn_on) may not need this." ), }, "data": { diff --git a/tools/honcho_tools.py b/tools/honcho_tools.py index a701c6468f..c6573723d1 100644 --- a/tools/honcho_tools.py +++ b/tools/honcho_tools.py @@ -65,6 +65,7 @@ HONCHO_TOOL_SCHEMA = { # ── Tool handler ── + def _handle_query_user_context(args: dict, **kw) -> str: """Execute the Honcho context query.""" query = args.get("query", "") @@ -84,6 +85,7 @@ def _handle_query_user_context(args: dict, **kw) -> str: # ── Availability check ── + def _check_honcho_available() -> bool: """Tool is only available when Honcho is active.""" return _session_manager is not None and _session_key is not None diff --git a/tools/image_generation_tool.py b/tools/image_generation_tool.py index 3789f38e70..1943d2f8b1 100644 --- a/tools/image_generation_tool.py +++ b/tools/image_generation_tool.py @@ -2,7 +2,7 @@ """ Image Generation Tools Module -This module provides image generation tools using FAL.ai's FLUX 2 Pro model with +This module provides image generation tools using FAL.ai's FLUX 2 Pro model with automatic upscaling via FAL.ai's Clarity Upscaler for enhanced image quality. Available tools: @@ -19,7 +19,7 @@ Features: Usage: from image_generation_tool import image_generate_tool import asyncio - + # Generate and automatically upscale an image result = await image_generate_tool( prompt="A serene mountain landscape with cherry blossoms", @@ -28,12 +28,14 @@ Usage: ) """ +import datetime import json import logging import os -import datetime -from typing import Dict, Any, Optional, Union +from typing import Any + import fal_client + from tools.debug_helpers import DebugSession logger = logging.getLogger(__name__) @@ -51,11 +53,7 @@ ENABLE_SAFETY_CHECKER = False SAFETY_TOLERANCE = "5" # Maximum tolerance (1-5, where 5 is most permissive) # Aspect ratio mapping - simplified choices for model to select -ASPECT_RATIO_MAP = { - "landscape": "landscape_16_9", - "square": "square_hd", - "portrait": "portrait_16_9" -} +ASPECT_RATIO_MAP = {"landscape": "landscape_16_9", "square": "square_hd", "portrait": "portrait_16_9"} VALID_ASPECT_RATIOS = list(ASPECT_RATIO_MAP.keys()) # Configuration for automatic upscaling @@ -70,9 +68,7 @@ UPSCALER_GUIDANCE_SCALE = 4 UPSCALER_NUM_INFERENCE_STEPS = 18 # Valid parameter values for validation based on FLUX 2 Pro documentation -VALID_IMAGE_SIZES = [ - "square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9" -] +VALID_IMAGE_SIZES = ["square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"] VALID_OUTPUT_FORMATS = ["jpeg", "png"] VALID_ACCELERATION_MODES = ["none", "regular", "high"] @@ -80,16 +76,16 @@ _debug = DebugSession("image_tools", env_var="IMAGE_TOOLS_DEBUG") def _validate_parameters( - image_size: Union[str, Dict[str, int]], + image_size: str | dict[str, int], num_inference_steps: int, guidance_scale: float, num_images: int, output_format: str, - acceleration: str = "none" -) -> Dict[str, Any]: + acceleration: str = "none", +) -> dict[str, Any]: """ Validate and normalize image generation parameters for FLUX 2 Pro model. - + Args: image_size: Either a preset string or custom size dict num_inference_steps: Number of inference steps @@ -97,15 +93,15 @@ def _validate_parameters( num_images: Number of images to generate output_format: Output format for images acceleration: Acceleration mode for generation speed - + Returns: Dict[str, Any]: Validated and normalized parameters - + Raises: ValueError: If any parameter is invalid """ validated = {} - + # Validate image_size if isinstance(image_size, str): if image_size not in VALID_IMAGE_SIZES: @@ -123,52 +119,52 @@ def _validate_parameters( validated["image_size"] = image_size else: raise ValueError("image_size must be either a preset string or a dict with width/height") - + # Validate num_inference_steps if not isinstance(num_inference_steps, int) or num_inference_steps < 1 or num_inference_steps > 100: raise ValueError("num_inference_steps must be an integer between 1 and 100") validated["num_inference_steps"] = num_inference_steps - + # Validate guidance_scale (FLUX 2 Pro default is 4.5) if not isinstance(guidance_scale, (int, float)) or guidance_scale < 0.1 or guidance_scale > 20.0: raise ValueError("guidance_scale must be a number between 0.1 and 20.0") validated["guidance_scale"] = float(guidance_scale) - + # Validate num_images if not isinstance(num_images, int) or num_images < 1 or num_images > 4: raise ValueError("num_images must be an integer between 1 and 4") validated["num_images"] = num_images - + # Validate output_format if output_format not in VALID_OUTPUT_FORMATS: raise ValueError(f"Invalid output_format '{output_format}'. Must be one of: {VALID_OUTPUT_FORMATS}") validated["output_format"] = output_format - + # Validate acceleration if acceleration not in VALID_ACCELERATION_MODES: raise ValueError(f"Invalid acceleration '{acceleration}'. Must be one of: {VALID_ACCELERATION_MODES}") validated["acceleration"] = acceleration - + return validated -def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]: +def _upscale_image(image_url: str, original_prompt: str) -> dict[str, Any]: """ Upscale an image using FAL.ai's Clarity Upscaler. - + Uses the synchronous fal_client API to avoid event loop lifecycle issues when called from threaded contexts (e.g. gateway thread pool). - + Args: image_url (str): URL of the image to upscale original_prompt (str): Original prompt used to generate the image - + Returns: Dict[str, Any]: Upscaled image data or None if upscaling fails """ try: logger.info("Upscaling image with Clarity Upscaler...") - + # Prepare arguments for upscaler upscaler_arguments = { "image_url": image_url, @@ -179,35 +175,36 @@ def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]: "resemblance": UPSCALER_RESEMBLANCE, "guidance_scale": UPSCALER_GUIDANCE_SCALE, "num_inference_steps": UPSCALER_NUM_INFERENCE_STEPS, - "enable_safety_checker": UPSCALER_SAFETY_CHECKER + "enable_safety_checker": UPSCALER_SAFETY_CHECKER, } - + # Use sync API — fal_client.submit() uses httpx.Client (no event loop). # The async API (submit_async) caches a global httpx.AsyncClient via # @cached_property, which breaks when asyncio.run() destroys the loop # between calls (gateway thread-pool pattern). - handler = fal_client.submit( - UPSCALER_MODEL, - arguments=upscaler_arguments - ) - + handler = fal_client.submit(UPSCALER_MODEL, arguments=upscaler_arguments) + # Get the upscaled result (sync — blocks until done) result = handler.get() - + if result and "image" in result: upscaled_image = result["image"] - logger.info("Image upscaled successfully to %sx%s", upscaled_image.get('width', 'unknown'), upscaled_image.get('height', 'unknown')) + logger.info( + "Image upscaled successfully to %sx%s", + upscaled_image.get("width", "unknown"), + upscaled_image.get("height", "unknown"), + ) return { "url": upscaled_image["url"], "width": upscaled_image.get("width", 0), "height": upscaled_image.get("height", 0), "upscaled": True, - "upscale_factor": UPSCALER_FACTOR + "upscale_factor": UPSCALER_FACTOR, } else: logger.error("Upscaler returned invalid response") return None - + except Exception as e: logger.error("Error upscaling image: %s", e) return None @@ -220,16 +217,16 @@ def image_generate_tool( guidance_scale: float = DEFAULT_GUIDANCE_SCALE, num_images: int = DEFAULT_NUM_IMAGES, output_format: str = DEFAULT_OUTPUT_FORMAT, - seed: Optional[int] = None + seed: int | None = None, ) -> str: """ Generate images from text prompts using FAL.ai's FLUX 2 Pro model with automatic upscaling. - + Uses the synchronous fal_client API to avoid event loop lifecycle issues. The async API's global httpx.AsyncClient (cached via @cached_property) breaks when asyncio.run() destroys and recreates event loops between calls, which happens in the gateway's thread-pool pattern. - + Args: prompt (str): The text prompt describing the desired image aspect_ratio (str): Image aspect ratio - "landscape", "square", or "portrait" (default: "landscape") @@ -238,7 +235,7 @@ def image_generate_tool( num_images (int): Number of images to generate (1-4, default: 1) output_format (str): Image format "jpeg" or "png" (default: "png") seed (Optional[int]): Random seed for reproducible results (optional) - + Returns: str: JSON string containing minimal generation results: { @@ -252,7 +249,7 @@ def image_generate_tool( logger.warning("Invalid aspect_ratio '%s', defaulting to '%s'", aspect_ratio, DEFAULT_ASPECT_RATIO) aspect_ratio_lower = DEFAULT_ASPECT_RATIO image_size = ASPECT_RATIO_MAP[aspect_ratio_lower] - + debug_call_data = { "parameters": { "prompt": prompt, @@ -262,32 +259,32 @@ def image_generate_tool( "guidance_scale": guidance_scale, "num_images": num_images, "output_format": output_format, - "seed": seed + "seed": seed, }, "error": None, "success": False, "images_generated": 0, - "generation_time": 0 + "generation_time": 0, } - + start_time = datetime.datetime.now() - + try: logger.info("Generating %s image(s) with FLUX 2 Pro: %s", num_images, prompt[:80]) - + # Validate prompt if not prompt or not isinstance(prompt, str) or len(prompt.strip()) == 0: raise ValueError("Prompt is required and must be a non-empty string") - + # Check API key availability if not os.getenv("FAL_KEY"): raise ValueError("FAL_KEY environment variable not set") - + # Validate other parameters validated_params = _validate_parameters( image_size, num_inference_steps, guidance_scale, num_images, output_format, "none" ) - + # Prepare arguments for FAL.ai FLUX 2 Pro API arguments = { "prompt": prompt.strip(), @@ -298,51 +295,44 @@ def image_generate_tool( "output_format": validated_params["output_format"], "enable_safety_checker": ENABLE_SAFETY_CHECKER, "safety_tolerance": SAFETY_TOLERANCE, - "sync_mode": True # Use sync mode for immediate results + "sync_mode": True, # Use sync mode for immediate results } - + # Add seed if provided if seed is not None and isinstance(seed, int): arguments["seed"] = seed - + logger.info("Submitting generation request to FAL.ai FLUX 2 Pro...") logger.info(" Model: %s", DEFAULT_MODEL) logger.info(" Aspect Ratio: %s -> %s", aspect_ratio_lower, image_size) - logger.info(" Steps: %s", validated_params['num_inference_steps']) - logger.info(" Guidance: %s", validated_params['guidance_scale']) - + logger.info(" Steps: %s", validated_params["num_inference_steps"]) + logger.info(" Guidance: %s", validated_params["guidance_scale"]) + # Submit request to FAL.ai using sync API (avoids cached event loop issues) - handler = fal_client.submit( - DEFAULT_MODEL, - arguments=arguments - ) - + handler = fal_client.submit(DEFAULT_MODEL, arguments=arguments) + # Get the result (sync — blocks until done) result = handler.get() - + generation_time = (datetime.datetime.now() - start_time).total_seconds() - + # Process the response if not result or "images" not in result: raise ValueError("Invalid response from FAL.ai API - no images returned") - + images = result.get("images", []) if not images: raise ValueError("No images were generated") - + # Format image data and upscale images formatted_images = [] for img in images: if isinstance(img, dict) and "url" in img: - original_image = { - "url": img["url"], - "width": img.get("width", 0), - "height": img.get("height", 0) - } - + original_image = {"url": img["url"], "width": img.get("width", 0), "height": img.get("height", 0)} + # Attempt to upscale the image upscaled_image = _upscale_image(img["url"], prompt.strip()) - + if upscaled_image: # Use upscaled image if successful formatted_images.append(upscaled_image) @@ -351,52 +341,48 @@ def image_generate_tool( logger.warning("Using original image as fallback") original_image["upscaled"] = False formatted_images.append(original_image) - + if not formatted_images: raise ValueError("No valid image URLs returned from API") - + upscaled_count = sum(1 for img in formatted_images if img.get("upscaled", False)) - logger.info("Generated %s image(s) in %.1fs (%s upscaled)", len(formatted_images), generation_time, upscaled_count) - + logger.info( + "Generated %s image(s) in %.1fs (%s upscaled)", len(formatted_images), generation_time, upscaled_count + ) + # Prepare successful response - minimal format - response_data = { - "success": True, - "image": formatted_images[0]["url"] if formatted_images else None - } - + response_data = {"success": True, "image": formatted_images[0]["url"] if formatted_images else None} + debug_call_data["success"] = True debug_call_data["images_generated"] = len(formatted_images) debug_call_data["generation_time"] = generation_time - + # Log debug information _debug.log_call("image_generate_tool", debug_call_data) _debug.save() - + return json.dumps(response_data, indent=2, ensure_ascii=False) - + except Exception as e: generation_time = (datetime.datetime.now() - start_time).total_seconds() error_msg = f"Error generating image: {str(e)}" logger.error("%s", error_msg) - + # Prepare error response - minimal format - response_data = { - "success": False, - "image": None - } - + response_data = {"success": False, "image": None} + debug_call_data["error"] = error_msg debug_call_data["generation_time"] = generation_time _debug.log_call("image_generate_tool", debug_call_data) _debug.save() - + return json.dumps(response_data, indent=2, ensure_ascii=False) def check_fal_api_key() -> bool: """ Check if the FAL.ai API key is available in environment variables. - + Returns: bool: True if API key is set, False otherwise """ @@ -406,7 +392,7 @@ def check_fal_api_key() -> bool: def check_image_generation_requirements() -> bool: """ Check if all requirements for image generation tools are met. - + Returns: bool: True if requirements are met, False otherwise """ @@ -414,19 +400,20 @@ def check_image_generation_requirements() -> bool: # Check API key if not check_fal_api_key(): return False - + # Check if fal_client is available import fal_client + return True - + except ImportError: return False -def get_debug_session_info() -> Dict[str, Any]: +def get_debug_session_info() -> dict[str, Any]: """ Get information about the current debug session. - + Returns: Dict[str, Any]: Dictionary containing debug session information """ @@ -439,10 +426,10 @@ if __name__ == "__main__": """ print("🎨 Image Generation Tools Module - FLUX 2 Pro + Auto Upscaling") print("=" * 60) - + # Check if API key is available api_available = check_fal_api_key() - + if not api_available: print("❌ FAL_KEY environment variable not set") print("Please set your API key: export FAL_KEY='your-key-here'") @@ -450,27 +437,28 @@ if __name__ == "__main__": exit(1) else: print("✅ FAL.ai API key found") - + # Check if fal_client is available try: import fal_client + print("✅ fal_client library available") except ImportError: print("❌ fal_client library not found") print("Please install: pip install fal-client") exit(1) - + print("🛠️ Image generation tools ready for use!") print(f"🤖 Using model: {DEFAULT_MODEL}") print(f"🔍 Auto-upscaling with: {UPSCALER_MODEL} ({UPSCALER_FACTOR}x)") - + # Show debug mode status if _debug.active: print(f"🐛 Debug mode ENABLED - Session ID: {_debug.session_id}") print(f" Debug logs will be saved to: ./logs/image_tools_debug_{_debug.session_id}.json") else: print("🐛 Debug mode disabled (set IMAGE_TOOLS_DEBUG=true to enable)") - + print("\nBasic usage:") print(" from image_generation_tool import image_generate_tool") print(" import asyncio") @@ -484,23 +472,23 @@ if __name__ == "__main__": print(" )") print(" print(result)") print(" asyncio.run(main())") - + print("\nSupported image sizes:") for size in VALID_IMAGE_SIZES: print(f" - {size}") print(" - Custom: {'width': 512, 'height': 768} (if needed)") - + print("\nAcceleration modes:") for mode in VALID_ACCELERATION_MODES: print(f" - {mode}") - + print("\nExample prompts:") print(" - 'A candid street photo of a woman with a pink bob and bold eyeliner'") print(" - 'Modern architecture building with glass facade, sunset lighting'") print(" - 'Abstract art with vibrant colors and geometric patterns'") print(" - 'Portrait of a wise old owl perched on ancient tree branch'") print(" - 'Futuristic cityscape with flying cars and neon lights'") - + print("\nDebug mode:") print(" # Enable debug logging") print(" export IMAGE_TOOLS_DEBUG=true") @@ -521,17 +509,17 @@ IMAGE_GENERATE_SCHEMA = { "properties": { "prompt": { "type": "string", - "description": "The text prompt describing the desired image. Be detailed and descriptive." + "description": "The text prompt describing the desired image. Be detailed and descriptive.", }, "aspect_ratio": { "type": "string", "enum": ["landscape", "square", "portrait"], "description": "The aspect ratio of the generated image. 'landscape' is 16:9 wide, 'portrait' is 16:9 tall, 'square' is 1:1.", - "default": "landscape" - } + "default": "landscape", + }, }, - "required": ["prompt"] - } + "required": ["prompt"], + }, } diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index deb87d4835..03388e0e0c 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -77,7 +77,7 @@ import os import re import threading import time -from typing import Any, Dict, List, Optional +from typing import Any logger = logging.getLogger(__name__) @@ -91,9 +91,11 @@ _MCP_SAMPLING_TYPES = False try: from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client + _MCP_AVAILABLE = True try: from mcp.client.streamable_http import streamablehttp_client + _MCP_HTTP_AVAILABLE = True except ImportError: _MCP_HTTP_AVAILABLE = False @@ -108,6 +110,7 @@ try: TextContent, ToolUseContent, ) + _MCP_SAMPLING_TYPES = True except ImportError: logger.debug("MCP sampling types not available -- sampling disabled") @@ -118,27 +121,36 @@ except ImportError: # Constants # --------------------------------------------------------------------------- -_DEFAULT_TOOL_TIMEOUT = 120 # seconds for tool calls -_DEFAULT_CONNECT_TIMEOUT = 60 # seconds for initial connection per server +_DEFAULT_TOOL_TIMEOUT = 120 # seconds for tool calls +_DEFAULT_CONNECT_TIMEOUT = 60 # seconds for initial connection per server _MAX_RECONNECT_RETRIES = 5 _MAX_BACKOFF_SECONDS = 60 # Environment variables that are safe to pass to stdio subprocesses -_SAFE_ENV_KEYS = frozenset({ - "PATH", "HOME", "USER", "LANG", "LC_ALL", "TERM", "SHELL", "TMPDIR", -}) +_SAFE_ENV_KEYS = frozenset( + { + "PATH", + "HOME", + "USER", + "LANG", + "LC_ALL", + "TERM", + "SHELL", + "TMPDIR", + } +) # Regex for credential patterns to strip from error messages _CREDENTIAL_PATTERN = re.compile( r"(?:" - r"ghp_[A-Za-z0-9_]{1,255}" # GitHub PAT - r"|sk-[A-Za-z0-9_]{1,255}" # OpenAI-style key - r"|Bearer\s+\S+" # Bearer token - r"|token=[^\s&,;\"']{1,255}" # token=... - r"|key=[^\s&,;\"']{1,255}" # key=... - r"|API_KEY=[^\s&,;\"']{1,255}" # API_KEY=... - r"|password=[^\s&,;\"']{1,255}" # password=... - r"|secret=[^\s&,;\"']{1,255}" # secret=... + r"ghp_[A-Za-z0-9_]{1,255}" # GitHub PAT + r"|sk-[A-Za-z0-9_]{1,255}" # OpenAI-style key + r"|Bearer\s+\S+" # Bearer token + r"|token=[^\s&,;\"']{1,255}" # token=... + r"|key=[^\s&,;\"']{1,255}" # key=... + r"|API_KEY=[^\s&,;\"']{1,255}" # API_KEY=... + r"|password=[^\s&,;\"']{1,255}" # password=... + r"|secret=[^\s&,;\"']{1,255}" # secret=... r")", re.IGNORECASE, ) @@ -148,7 +160,8 @@ _CREDENTIAL_PATTERN = re.compile( # Security helpers # --------------------------------------------------------------------------- -def _build_safe_env(user_env: Optional[dict]) -> dict: + +def _build_safe_env(user_env: dict | None) -> dict: """Build a filtered environment dict for stdio subprocesses. Only passes through safe baseline variables (PATH, HOME, etc.) and XDG_* @@ -180,6 +193,7 @@ def _sanitize_error(text: str) -> str: # Sampling -- server-initiated LLM requests (MCP sampling/createMessage) # --------------------------------------------------------------------------- + def _safe_numeric(value, default, coerce=int, minimum=1): """Coerce a config value to a numeric type, returning *default* on failure. @@ -216,18 +230,22 @@ class SamplingHandler: self.timeout = _safe_numeric(config.get("timeout", 30), 30, float) self.max_tokens_cap = _safe_numeric(config.get("max_tokens_cap", 4096), 4096, int) self.max_tool_rounds = _safe_numeric( - config.get("max_tool_rounds", 5), 5, int, minimum=0, + config.get("max_tool_rounds", 5), + 5, + int, + minimum=0, ) self.model_override = config.get("model") self.allowed_models = config.get("allowed_models", []) _log_levels = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING} self.audit_level = _log_levels.get( - str(config.get("log_level", "info")).lower(), logging.INFO, + str(config.get("log_level", "info")).lower(), + logging.INFO, ) # Per-instance state - self._rate_timestamps: List[float] = [] + self._rate_timestamps: list[float] = [] self._tool_loop_count = 0 self.metrics = {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0} @@ -245,7 +263,7 @@ class SamplingHandler: # -- Model resolution ---------------------------------------------------- - def _resolve_model(self, preferences) -> Optional[str]: + def _resolve_model(self, preferences) -> str | None: """Config override > server hint > None (use default).""" if self.model_override: return self.model_override @@ -265,7 +283,7 @@ class SamplingHandler: items = block.content if isinstance(block.content, list) else [block.content] return "\n".join(item.text for item in items if hasattr(item, "text")) - def _convert_messages(self, params) -> List[dict]: + def _convert_messages(self, params) -> list[dict]: """Convert MCP SamplingMessages to OpenAI format. Uses ``msg.content_as_list`` (SDK helper) so single-block and @@ -273,37 +291,47 @@ class SamplingHandler: with ``isinstance`` on real SDK types when available, falling back to duck-typing via ``hasattr`` for compatibility. """ - messages: List[dict] = [] + messages: list[dict] = [] for msg in params.messages: - blocks = msg.content_as_list if hasattr(msg, "content_as_list") else ( - msg.content if isinstance(msg.content, list) else [msg.content] + blocks = ( + msg.content_as_list + if hasattr(msg, "content_as_list") + else (msg.content if isinstance(msg.content, list) else [msg.content]) ) # Separate blocks by kind tool_results = [b for b in blocks if hasattr(b, "toolUseId")] - tool_uses = [b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")] - content_blocks = [b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))] + tool_uses = [ + b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId") + ] + content_blocks = [ + b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input")) + ] # Emit tool result messages (role: tool) for tr in tool_results: - messages.append({ - "role": "tool", - "tool_call_id": tr.toolUseId, - "content": self._extract_tool_result_text(tr), - }) + messages.append( + { + "role": "tool", + "tool_call_id": tr.toolUseId, + "content": self._extract_tool_result_text(tr), + } + ) # Emit assistant tool_calls message if tool_uses: tc_list = [] for tu in tool_uses: - tc_list.append({ - "id": getattr(tu, "id", f"call_{len(tc_list)}"), - "type": "function", - "function": { - "name": tu.name, - "arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input), - }, - }) + tc_list.append( + { + "id": getattr(tu, "id", f"call_{len(tc_list)}"), + "type": "function", + "function": { + "name": tu.name, + "arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input), + }, + } + ) msg_dict: dict = {"role": msg.role, "tool_calls": tc_list} # Include any accompanying text text_parts = [b.text for b in content_blocks if hasattr(b, "text")] @@ -320,10 +348,12 @@ class SamplingHandler: if hasattr(block, "text"): parts.append({"type": "text", "text": block.text}) elif hasattr(block, "data") and hasattr(block, "mimeType"): - parts.append({ - "type": "image_url", - "image_url": {"url": f"data:{block.mimeType};base64,{block.data}"}, - }) + parts.append( + { + "type": "image_url", + "image_url": {"url": f"data:{block.mimeType};base64,{block.data}"}, + } + ) else: logger.warning( "Unsupported sampling content block type: %s (skipped)", @@ -352,16 +382,13 @@ class SamplingHandler: # Tool loop governance if self.max_tool_rounds == 0: self._tool_loop_count = 0 - return self._error( - f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)" - ) + return self._error(f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)") self._tool_loop_count += 1 if self._tool_loop_count > self.max_tool_rounds: self._tool_loop_count = 0 return self._error( - f"Tool loop limit exceeded for server '{self.server_name}' " - f"(max {self.max_tool_rounds} rounds)" + f"Tool loop limit exceeded for server '{self.server_name}' (max {self.max_tool_rounds} rounds)" ) content_blocks = [] @@ -372,25 +399,28 @@ class SamplingHandler: parsed = json.loads(args) except (json.JSONDecodeError, ValueError): logger.warning( - "MCP server '%s': malformed tool_calls arguments " - "from LLM (wrapping as raw): %.100s", - self.server_name, args, + "MCP server '%s': malformed tool_calls arguments from LLM (wrapping as raw): %.100s", + self.server_name, + args, ) parsed = {"_raw": args} else: parsed = args if isinstance(args, dict) else {"_raw": str(args)} - content_blocks.append(ToolUseContent( - type="tool_use", - id=tc.id, - name=tc.function.name, - input=parsed, - )) + content_blocks.append( + ToolUseContent( + type="tool_use", + id=tc.id, + name=tc.function.name, + input=parsed, + ) + ) logger.log( self.audit_level, "MCP server '%s' sampling response: model=%s, tokens=%s, tool_calls=%d", - self.server_name, response.model, + self.server_name, + response.model, getattr(getattr(response, "usage", None), "total_tokens", "?"), len(content_blocks), ) @@ -410,7 +440,8 @@ class SamplingHandler: logger.log( self.audit_level, "MCP server '%s' sampling response: model=%s, tokens=%s", - self.server_name, response.model, + self.server_name, + response.model, getattr(getattr(response, "usage", None), "total_tokens", "?"), ) @@ -445,12 +476,12 @@ class SamplingHandler: if not self._check_rate_limit(): logger.warning( "MCP server '%s' sampling rate limit exceeded (%d/min)", - self.server_name, self.max_rpm, + self.server_name, + self.max_rpm, ) self.metrics["errors"] += 1 return self._error( - f"Sampling rate limit exceeded for server '{self.server_name}' " - f"({self.max_rpm} requests/minute)" + f"Sampling rate limit exceeded for server '{self.server_name}' ({self.max_rpm} requests/minute)" ) # Resolve model @@ -458,6 +489,7 @@ class SamplingHandler: # Get auxiliary LLM client from agent.auxiliary_client import get_text_auxiliary_client + client, default_model = get_text_auxiliary_client() if client is None: self.metrics["errors"] += 1 @@ -469,7 +501,8 @@ class SamplingHandler: if self.allowed_models and resolved_model not in self.allowed_models: logger.warning( "MCP server '%s' requested model '%s' not in allowed_models", - self.server_name, resolved_model, + self.server_name, + resolved_model, ) self.metrics["errors"] += 1 return self._error( @@ -515,7 +548,10 @@ class SamplingHandler: logger.log( self.audit_level, "MCP server '%s' sampling request: model=%s, max_tokens=%d, messages=%d", - self.server_name, resolved_model, max_tokens, len(messages), + self.server_name, + resolved_model, + max_tokens, + len(messages), ) # Offload sync LLM call to thread (non-blocking) @@ -524,19 +560,15 @@ class SamplingHandler: try: response = await asyncio.wait_for( - asyncio.to_thread(_sync_call), timeout=self.timeout, + asyncio.to_thread(_sync_call), + timeout=self.timeout, ) - except asyncio.TimeoutError: + except TimeoutError: self.metrics["errors"] += 1 - return self._error( - f"Sampling LLM call timed out after {self.timeout}s " - f"for server '{self.server_name}'" - ) + return self._error(f"Sampling LLM call timed out after {self.timeout}s for server '{self.server_name}'") except Exception as exc: self.metrics["errors"] += 1 - return self._error( - f"Sampling LLM call failed: {_sanitize_error(str(exc))}" - ) + return self._error(f"Sampling LLM call failed: {_sanitize_error(str(exc))}") # Track metrics choice = response.choices[0] @@ -546,11 +578,7 @@ class SamplingHandler: self.metrics["tokens_used"] += total_tokens # Dispatch based on response type - if ( - choice.finish_reason == "tool_calls" - and hasattr(choice.message, "tool_calls") - and choice.message.tool_calls - ): + if choice.finish_reason == "tool_calls" and hasattr(choice.message, "tool_calls") and choice.message.tool_calls: return self._build_tool_use_result(choice, response) return self._build_text_result(choice, response) @@ -560,6 +588,7 @@ class SamplingHandler: # Server task -- each MCP server lives in one long-lived asyncio Task # --------------------------------------------------------------------------- + class MCPServerTask: """Manages a single MCP server connection in a dedicated asyncio Task. @@ -571,22 +600,29 @@ class MCPServerTask: """ __slots__ = ( - "name", "session", "tool_timeout", - "_task", "_ready", "_shutdown_event", "_tools", "_error", "_config", + "name", + "session", + "tool_timeout", + "_task", + "_ready", + "_shutdown_event", + "_tools", + "_error", + "_config", "_sampling", ) def __init__(self, name: str): self.name = name - self.session: Optional[Any] = None + self.session: Any | None = None self.tool_timeout: float = _DEFAULT_TOOL_TIMEOUT - self._task: Optional[asyncio.Task] = None + self._task: asyncio.Task | None = None self._ready = asyncio.Event() self._shutdown_event = asyncio.Event() self._tools: list = [] - self._error: Optional[Exception] = None + self._error: Exception | None = None self._config: dict = {} - self._sampling: Optional[SamplingHandler] = None + self._sampling: SamplingHandler | None = None def _is_http(self) -> bool: """Check if this server uses HTTP transport.""" @@ -599,9 +635,7 @@ class MCPServerTask: user_env = config.get("env") if not command: - raise ValueError( - f"MCP server '{self.name}' has no 'command' in config" - ) + raise ValueError(f"MCP server '{self.name}' has no 'command' in config") safe_env = _build_safe_env(user_env) server_params = StdioServerParameters( @@ -650,11 +684,7 @@ class MCPServerTask: if self.session is None: return tools_result = await self.session.list_tools() - self._tools = ( - tools_result.tools - if hasattr(tools_result, "tools") - else [] - ) + self._tools = tools_result.tools if hasattr(tools_result, "tools") else [] async def run(self, config: dict): """Long-lived coroutine: connect, discover tools, wait, disconnect. @@ -704,24 +734,28 @@ class MCPServerTask: if self._shutdown_event.is_set(): logger.debug( "MCP server '%s' disconnected during shutdown: %s", - self.name, exc, + self.name, + exc, ) return retries += 1 if retries > _MAX_RECONNECT_RETRIES: logger.warning( - "MCP server '%s' failed after %d reconnection attempts, " - "giving up: %s", - self.name, _MAX_RECONNECT_RETRIES, exc, + "MCP server '%s' failed after %d reconnection attempts, giving up: %s", + self.name, + _MAX_RECONNECT_RETRIES, + exc, ) return logger.warning( - "MCP server '%s' connection lost (attempt %d/%d), " - "reconnecting in %.0fs: %s", - self.name, retries, _MAX_RECONNECT_RETRIES, - backoff, exc, + "MCP server '%s' connection lost (attempt %d/%d), reconnecting in %.0fs: %s", + self.name, + retries, + _MAX_RECONNECT_RETRIES, + backoff, + exc, ) await asyncio.sleep(backoff) backoff = min(backoff * 2, _MAX_BACKOFF_SECONDS) @@ -745,7 +779,7 @@ class MCPServerTask: if self._task and not self._task.done(): try: await asyncio.wait_for(self._task, timeout=10) - except asyncio.TimeoutError: + except TimeoutError: logger.warning( "MCP server '%s' shutdown timed out, cancelling task", self.name, @@ -762,11 +796,11 @@ class MCPServerTask: # Module-level state # --------------------------------------------------------------------------- -_servers: Dict[str, MCPServerTask] = {} +_servers: dict[str, MCPServerTask] = {} # Dedicated event loop running in a background daemon thread. -_mcp_loop: Optional[asyncio.AbstractEventLoop] = None -_mcp_thread: Optional[threading.Thread] = None +_mcp_loop: asyncio.AbstractEventLoop | None = None +_mcp_thread: threading.Thread | None = None # Protects _mcp_loop, _mcp_thread, and _servers from concurrent access. _lock = threading.Lock() @@ -801,7 +835,8 @@ def _run_on_mcp_loop(coro, timeout: float = 30): # Config loading # --------------------------------------------------------------------------- -def _load_mcp_config() -> Dict[str, dict]: + +def _load_mcp_config() -> dict[str, dict]: """Read ``mcp_servers`` from the Hermes config file. Returns a dict of ``{server_name: server_config}`` or empty dict. @@ -811,6 +846,7 @@ def _load_mcp_config() -> Dict[str, dict]: """ try: from hermes_cli.config import load_config + config = load_config() servers = config.get("mcp_servers") if not servers or not isinstance(servers, dict): @@ -825,6 +861,7 @@ def _load_mcp_config() -> Dict[str, dict]: # Server connection helper # --------------------------------------------------------------------------- + async def _connect_server(name: str, config: dict) -> MCPServerTask: """Create an MCPServerTask, start it, and return when ready. @@ -845,6 +882,7 @@ async def _connect_server(name: str, config: dict) -> MCPServerTask: # Handler / check-fn factories # --------------------------------------------------------------------------- + def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float): """Return a sync handler that calls an MCP tool via the background loop. @@ -856,27 +894,21 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float): with _lock: server = _servers.get(server_name) if not server or not server.session: - return json.dumps({ - "error": f"MCP server '{server_name}' is not connected" - }) + return json.dumps({"error": f"MCP server '{server_name}' is not connected"}) async def _call(): result = await server.session.call_tool(tool_name, arguments=args) # MCP CallToolResult has .content (list of content blocks) and .isError if result.isError: error_text = "" - for block in (result.content or []): + for block in result.content or []: if hasattr(block, "text"): error_text += block.text - return json.dumps({ - "error": _sanitize_error( - error_text or "MCP tool returned an error" - ) - }) + return json.dumps({"error": _sanitize_error(error_text or "MCP tool returned an error")}) # Collect text from content blocks - parts: List[str] = [] - for block in (result.content or []): + parts: list[str] = [] + for block in result.content or []: if hasattr(block, "text"): parts.append(block.text) return json.dumps({"result": "\n".join(parts) if parts else ""}) @@ -886,13 +918,11 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float): except Exception as exc: logger.error( "MCP tool %s/%s call failed: %s", - server_name, tool_name, exc, + server_name, + tool_name, + exc, ) - return json.dumps({ - "error": _sanitize_error( - f"MCP call failed: {type(exc).__name__}: {exc}" - ) - }) + return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")}) return _handler @@ -904,14 +934,12 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float): with _lock: server = _servers.get(server_name) if not server or not server.session: - return json.dumps({ - "error": f"MCP server '{server_name}' is not connected" - }) + return json.dumps({"error": f"MCP server '{server_name}' is not connected"}) async def _call(): result = await server.session.list_resources() resources = [] - for r in (result.resources if hasattr(result, "resources") else []): + for r in result.resources if hasattr(result, "resources") else []: entry = {} if hasattr(r, "uri"): entry["uri"] = str(r.uri) @@ -928,13 +956,11 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float): return _run_on_mcp_loop(_call(), timeout=tool_timeout) except Exception as exc: logger.error( - "MCP %s/list_resources failed: %s", server_name, exc, + "MCP %s/list_resources failed: %s", + server_name, + exc, ) - return json.dumps({ - "error": _sanitize_error( - f"MCP call failed: {type(exc).__name__}: {exc}" - ) - }) + return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")}) return _handler @@ -946,9 +972,7 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float): with _lock: server = _servers.get(server_name) if not server or not server.session: - return json.dumps({ - "error": f"MCP server '{server_name}' is not connected" - }) + return json.dumps({"error": f"MCP server '{server_name}' is not connected"}) uri = args.get("uri") if not uri: @@ -957,7 +981,7 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float): async def _call(): result = await server.session.read_resource(uri) # read_resource returns ReadResourceResult with .contents list - parts: List[str] = [] + parts: list[str] = [] contents = result.contents if hasattr(result, "contents") else [] for block in contents: if hasattr(block, "text"): @@ -970,13 +994,11 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float): return _run_on_mcp_loop(_call(), timeout=tool_timeout) except Exception as exc: logger.error( - "MCP %s/read_resource failed: %s", server_name, exc, + "MCP %s/read_resource failed: %s", + server_name, + exc, ) - return json.dumps({ - "error": _sanitize_error( - f"MCP call failed: {type(exc).__name__}: {exc}" - ) - }) + return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")}) return _handler @@ -988,14 +1010,12 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float): with _lock: server = _servers.get(server_name) if not server or not server.session: - return json.dumps({ - "error": f"MCP server '{server_name}' is not connected" - }) + return json.dumps({"error": f"MCP server '{server_name}' is not connected"}) async def _call(): result = await server.session.list_prompts() prompts = [] - for p in (result.prompts if hasattr(result, "prompts") else []): + for p in result.prompts if hasattr(result, "prompts") else []: entry = {} if hasattr(p, "name"): entry["name"] = p.name @@ -1017,13 +1037,11 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float): return _run_on_mcp_loop(_call(), timeout=tool_timeout) except Exception as exc: logger.error( - "MCP %s/list_prompts failed: %s", server_name, exc, + "MCP %s/list_prompts failed: %s", + server_name, + exc, ) - return json.dumps({ - "error": _sanitize_error( - f"MCP call failed: {type(exc).__name__}: {exc}" - ) - }) + return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")}) return _handler @@ -1035,9 +1053,7 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float): with _lock: server = _servers.get(server_name) if not server or not server.session: - return json.dumps({ - "error": f"MCP server '{server_name}' is not connected" - }) + return json.dumps({"error": f"MCP server '{server_name}' is not connected"}) name = args.get("name") if not name: @@ -1048,7 +1064,7 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float): result = await server.session.get_prompt(name, arguments=arguments) # GetPromptResult has .messages list messages = [] - for msg in (result.messages if hasattr(result, "messages") else []): + for msg in result.messages if hasattr(result, "messages") else []: entry = {} if hasattr(msg, "role"): entry["role"] = msg.role @@ -1070,13 +1086,11 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float): return _run_on_mcp_loop(_call(), timeout=tool_timeout) except Exception as exc: logger.error( - "MCP %s/get_prompt failed: %s", server_name, exc, + "MCP %s/get_prompt failed: %s", + server_name, + exc, ) - return json.dumps({ - "error": _sanitize_error( - f"MCP call failed: {type(exc).__name__}: {exc}" - ) - }) + return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")}) return _handler @@ -1096,6 +1110,7 @@ def _make_check_fn(server_name: str): # Discovery & registration # --------------------------------------------------------------------------- + def _convert_mcp_schema(server_name: str, mcp_tool) -> dict: """Convert an MCP tool listing to the Hermes registry schema format. @@ -1114,14 +1129,16 @@ def _convert_mcp_schema(server_name: str, mcp_tool) -> dict: return { "name": prefixed_name, "description": mcp_tool.description or f"MCP tool {mcp_tool.name} from {server_name}", - "parameters": mcp_tool.inputSchema if mcp_tool.inputSchema else { + "parameters": mcp_tool.inputSchema + if mcp_tool.inputSchema + else { "type": "object", "properties": {}, }, } -def _build_utility_schemas(server_name: str) -> List[dict]: +def _build_utility_schemas(server_name: str) -> list[dict]: """Build schemas for the MCP utility tools (resources & prompts). Returns a list of (schema, handler_factory_name) tuples encoded as dicts @@ -1192,9 +1209,9 @@ def _build_utility_schemas(server_name: str) -> List[dict]: ] -def _existing_tool_names() -> List[str]: +def _existing_tool_names() -> list[str]: """Return tool names for all currently connected servers.""" - names: List[str] = [] + names: list[str] = [] for sname, server in _servers.items(): for mcp_tool in server._tools: schema = _convert_mcp_schema(sname, mcp_tool) @@ -1205,7 +1222,7 @@ def _existing_tool_names() -> List[str]: return names -async def _discover_and_register_server(name: str, config: dict) -> List[str]: +async def _discover_and_register_server(name: str, config: dict) -> list[str]: """Connect to a single MCP server, discover tools, and register them. Also registers utility tools for MCP Resources and Prompts support @@ -1224,7 +1241,7 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: with _lock: _servers[name] = server - registered_names: List[str] = [] + registered_names: list[str] = [] toolset_name = f"mcp-{name}" for mcp_tool in server._tools: @@ -1277,7 +1294,9 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: transport_type = "HTTP" if "url" in config else "stdio" logger.info( "MCP server '%s' (%s): registered %d tool(s): %s", - name, transport_type, len(registered_names), + name, + transport_type, + len(registered_names), ", ".join(registered_names), ) return registered_names @@ -1287,7 +1306,8 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: # Public API # --------------------------------------------------------------------------- -def discover_mcp_tools() -> List[str]: + +def discover_mcp_tools() -> list[str]: """Entry point: load config, connect to MCP servers, register tools. Called from ``model_tools._discover_tools()``. Safe to call even when @@ -1318,12 +1338,12 @@ def discover_mcp_tools() -> List[str]: # Start the background event loop for MCP connections _ensure_mcp_loop() - all_tools: List[str] = [] + all_tools: list[str] = [] failed_count = 0 - async def _discover_one(name: str, cfg: dict) -> List[str]: + async def _discover_one(name: str, cfg: dict) -> list[str]: """Connect to a single server and return its registered tool names.""" - transport_desc = cfg.get("url", f'{cfg.get("command", "?")} {" ".join(cfg.get("args", [])[:2])}') + transport_desc = cfg.get("url", f"{cfg.get('command', '?')} {' '.join(cfg.get('args', [])[:2])}") try: registered = await _discover_and_register_server(name, cfg) transport_type = "HTTP" if "url" in cfg else "stdio" @@ -1331,7 +1351,8 @@ def discover_mcp_tools() -> List[str]: except Exception as exc: logger.warning( "Failed to connect to MCP server '%s': %s", - name, exc, + name, + exc, ) return [] @@ -1358,6 +1379,7 @@ def discover_mcp_tools() -> List[str]: if all_tools: # Dynamically inject into all hermes-* platform toolsets from toolsets import TOOLSETS + for ts_name, ts in TOOLSETS.items(): if ts_name.startswith("hermes-"): for tool_name in all_tools: @@ -1377,13 +1399,13 @@ def discover_mcp_tools() -> List[str]: return _existing_tool_names() -def get_mcp_status() -> List[dict]: +def get_mcp_status() -> list[dict]: """Return status of all configured MCP servers for banner display. Returns a list of dicts with keys: name, transport, tools, connected. Includes both successfully connected servers and configured-but-failed ones. """ - result: List[dict] = [] + result: list[dict] = [] # Get configured servers from config configured = _load_mcp_config() @@ -1407,12 +1429,14 @@ def get_mcp_status() -> List[dict]: entry["sampling"] = dict(server._sampling.metrics) result.append(entry) else: - result.append({ - "name": name, - "transport": transport, - "tools": 0, - "connected": False, - }) + result.append( + { + "name": name, + "transport": transport, + "tools": 0, + "connected": False, + } + ) return result @@ -1440,7 +1464,9 @@ def shutdown_mcp_servers(): for server, result in zip(servers_snapshot, results): if isinstance(result, Exception): logger.debug( - "Error closing MCP server '%s': %s", server.name, result, + "Error closing MCP server '%s': %s", + server.name, + result, ) with _lock: _servers.clear() diff --git a/tools/memory_tool.py b/tools/memory_tool.py index 2ce7631240..24c06cc0b0 100644 --- a/tools/memory_tool.py +++ b/tools/memory_tool.py @@ -29,7 +29,7 @@ import os import re import tempfile from pathlib import Path -from typing import Dict, Any, List, Optional +from typing import Any logger = logging.getLogger(__name__) @@ -46,30 +46,38 @@ ENTRY_DELIMITER = "\n§\n" _MEMORY_THREAT_PATTERNS = [ # Prompt injection - (r'ignore\s+(previous|all|above|prior)\s+instructions', "prompt_injection"), - (r'you\s+are\s+now\s+', "role_hijack"), - (r'do\s+not\s+tell\s+the\s+user', "deception_hide"), - (r'system\s+prompt\s+override', "sys_prompt_override"), - (r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"), - (r'act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)', "bypass_restrictions"), + (r"ignore\s+(previous|all|above|prior)\s+instructions", "prompt_injection"), + (r"you\s+are\s+now\s+", "role_hijack"), + (r"do\s+not\s+tell\s+the\s+user", "deception_hide"), + (r"system\s+prompt\s+override", "sys_prompt_override"), + (r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"), + (r"act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)", "bypass_restrictions"), # Exfiltration via curl/wget with secrets - (r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"), - (r'wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_wget"), - (r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)', "read_secrets"), + (r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"), + (r"wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_wget"), + (r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)", "read_secrets"), # Persistence via shell rc - (r'authorized_keys', "ssh_backdoor"), - (r'\$HOME/\.ssh|\~/\.ssh', "ssh_access"), - (r'\$HOME/\.hermes/\.env|\~/\.hermes/\.env', "hermes_env"), + (r"authorized_keys", "ssh_backdoor"), + (r"\$HOME/\.ssh|\~/\.ssh", "ssh_access"), + (r"\$HOME/\.hermes/\.env|\~/\.hermes/\.env", "hermes_env"), ] # Subset of invisible chars for injection detection _INVISIBLE_CHARS = { - '\u200b', '\u200c', '\u200d', '\u2060', '\ufeff', - '\u202a', '\u202b', '\u202c', '\u202d', '\u202e', + "\u200b", + "\u200c", + "\u200d", + "\u2060", + "\ufeff", + "\u202a", + "\u202b", + "\u202c", + "\u202d", + "\u202e", } -def _scan_memory_content(content: str) -> Optional[str]: +def _scan_memory_content(content: str) -> str | None: """Scan memory content for injection/exfil patterns. Returns error string if blocked.""" # Check invisible unicode for char in _INVISIBLE_CHARS: @@ -96,12 +104,12 @@ class MemoryStore: """ def __init__(self, memory_char_limit: int = 2200, user_char_limit: int = 1375): - self.memory_entries: List[str] = [] - self.user_entries: List[str] = [] + self.memory_entries: list[str] = [] + self.user_entries: list[str] = [] self.memory_char_limit = memory_char_limit self.user_char_limit = user_char_limit # Frozen snapshot for system prompt -- set once at load_from_disk() - self._system_prompt_snapshot: Dict[str, str] = {"memory": "", "user": ""} + self._system_prompt_snapshot: dict[str, str] = {"memory": "", "user": ""} def load_from_disk(self): """Load entries from MEMORY.md and USER.md, capture system prompt snapshot.""" @@ -129,12 +137,12 @@ class MemoryStore: elif target == "user": self._write_file(MEMORY_DIR / "USER.md", self.user_entries) - def _entries_for(self, target: str) -> List[str]: + def _entries_for(self, target: str) -> list[str]: if target == "user": return self.user_entries return self.memory_entries - def _set_entries(self, target: str, entries: List[str]): + def _set_entries(self, target: str, entries: list[str]): if target == "user": self.user_entries = entries else: @@ -151,7 +159,7 @@ class MemoryStore: return self.user_char_limit return self.memory_char_limit - def add(self, target: str, content: str) -> Dict[str, Any]: + def add(self, target: str, content: str) -> dict[str, Any]: """Append a new entry. Returns error if it would exceed the char limit.""" content = content.strip() if not content: @@ -192,7 +200,7 @@ class MemoryStore: return self._success_response(target, "Entry added.") - def replace(self, target: str, old_text: str, new_content: str) -> Dict[str, Any]: + def replace(self, target: str, old_text: str, new_content: str) -> dict[str, Any]: """Find entry containing old_text substring, replace it with new_content.""" old_text = old_text.strip() new_content = new_content.strip() @@ -247,7 +255,7 @@ class MemoryStore: return self._success_response(target, "Entry replaced.") - def remove(self, target: str, old_text: str) -> Dict[str, Any]: + def remove(self, target: str, old_text: str) -> dict[str, Any]: """Remove the entry containing old_text substring.""" old_text = old_text.strip() if not old_text: @@ -278,7 +286,7 @@ class MemoryStore: return self._success_response(target, "Entry removed.") - def format_for_system_prompt(self, target: str) -> Optional[str]: + def format_for_system_prompt(self, target: str) -> str | None: """ Return the frozen snapshot for system prompt injection. @@ -293,7 +301,7 @@ class MemoryStore: # -- Internal helpers -- - def _success_response(self, target: str, message: str = None) -> Dict[str, Any]: + def _success_response(self, target: str, message: str = None) -> dict[str, Any]: entries = self._entries_for(target) current = self._char_count(target) limit = self._char_limit(target) @@ -310,7 +318,7 @@ class MemoryStore: resp["message"] = message return resp - def _render_block(self, target: str, entries: List[str]) -> str: + def _render_block(self, target: str, entries: list[str]) -> str: """Render a system prompt block with header and usage indicator.""" if not entries: return "" @@ -329,7 +337,7 @@ class MemoryStore: return f"{separator}\n{header}\n{separator}\n{content}" @staticmethod - def _read_file(path: Path) -> List[str]: + def _read_file(path: Path) -> list[str]: """Read a memory file and split into entries. No file locking needed: _write_file uses atomic rename, so readers @@ -339,7 +347,7 @@ class MemoryStore: return [] try: raw = path.read_text(encoding="utf-8") - except (OSError, IOError): + except OSError: return [] if not raw.strip(): @@ -351,7 +359,7 @@ class MemoryStore: return [e for e in entries if e] @staticmethod - def _write_file(path: Path, entries: List[str]): + def _write_file(path: Path, entries: list[str]): """Write entries to a memory file using atomic temp-file + rename. Previous implementation used open("w") + flock, but "w" truncates the @@ -362,9 +370,7 @@ class MemoryStore: content = ENTRY_DELIMITER.join(entries) if entries else "" try: # Write to temp file in same directory (same filesystem for atomic rename) - fd, tmp_path = tempfile.mkstemp( - dir=str(path.parent), suffix=".tmp", prefix=".mem_" - ) + fd, tmp_path = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp", prefix=".mem_") try: with os.fdopen(fd, "w", encoding="utf-8") as f: f.write(content) @@ -378,7 +384,7 @@ class MemoryStore: except OSError: pass raise - except (OSError, IOError) as e: + except OSError as e: raise RuntimeError(f"Failed to write memory file {path}: {e}") @@ -387,7 +393,7 @@ def memory_tool( target: str = "memory", content: str = None, old_text: str = None, - store: Optional[MemoryStore] = None, + store: MemoryStore | None = None, ) -> str: """ Single entry point for the memory tool. Dispatches to MemoryStore methods. @@ -395,10 +401,15 @@ def memory_tool( Returns JSON string with results. """ if store is None: - return json.dumps({"success": False, "error": "Memory is not available. It may be disabled in config or this environment."}, ensure_ascii=False) + return json.dumps( + {"success": False, "error": "Memory is not available. It may be disabled in config or this environment."}, + ensure_ascii=False, + ) if target not in ("memory", "user"): - return json.dumps({"success": False, "error": f"Invalid target '{target}'. Use 'memory' or 'user'."}, ensure_ascii=False) + return json.dumps( + {"success": False, "error": f"Invalid target '{target}'. Use 'memory' or 'user'."}, ensure_ascii=False + ) if action == "add": if not content: @@ -407,18 +418,26 @@ def memory_tool( elif action == "replace": if not old_text: - return json.dumps({"success": False, "error": "old_text is required for 'replace' action."}, ensure_ascii=False) + return json.dumps( + {"success": False, "error": "old_text is required for 'replace' action."}, ensure_ascii=False + ) if not content: - return json.dumps({"success": False, "error": "content is required for 'replace' action."}, ensure_ascii=False) + return json.dumps( + {"success": False, "error": "content is required for 'replace' action."}, ensure_ascii=False + ) result = store.replace(target, old_text, content) elif action == "remove": if not old_text: - return json.dumps({"success": False, "error": "old_text is required for 'remove' action."}, ensure_ascii=False) + return json.dumps( + {"success": False, "error": "old_text is required for 'remove' action."}, ensure_ascii=False + ) result = store.remove(target, old_text) else: - return json.dumps({"success": False, "error": f"Unknown action '{action}'. Use: add, replace, remove"}, ensure_ascii=False) + return json.dumps( + {"success": False, "error": f"Unknown action '{action}'. Use: add, replace, remove"}, ensure_ascii=False + ) return json.dumps(result, ensure_ascii=False) @@ -457,23 +476,16 @@ MEMORY_SCHEMA = { "parameters": { "type": "object", "properties": { - "action": { - "type": "string", - "enum": ["add", "replace", "remove"], - "description": "The action to perform." - }, + "action": {"type": "string", "enum": ["add", "replace", "remove"], "description": "The action to perform."}, "target": { "type": "string", "enum": ["memory", "user"], - "description": "Which memory store: 'memory' for personal notes, 'user' for user profile." - }, - "content": { - "type": "string", - "description": "The entry content. Required for 'add' and 'replace'." + "description": "Which memory store: 'memory' for personal notes, 'user' for user profile.", }, + "content": {"type": "string", "description": "The entry content. Required for 'add' and 'replace'."}, "old_text": { "type": "string", - "description": "Short unique substring identifying the entry to replace or remove." + "description": "Short unique substring identifying the entry to replace or remove.", }, }, "required": ["action", "target"], @@ -493,10 +505,7 @@ registry.register( target=args.get("target", "memory"), content=args.get("content"), old_text=args.get("old_text"), - store=kw.get("store")), + store=kw.get("store"), + ), check_fn=check_memory_requirements, ) - - - - diff --git a/tools/mixture_of_agents_tool.py b/tools/mixture_of_agents_tool.py index 355419817f..8418804871 100644 --- a/tools/mixture_of_agents_tool.py +++ b/tools/mixture_of_agents_tool.py @@ -38,21 +38,27 @@ Configuration: Usage: from mixture_of_agents_tool import mixture_of_agents_tool import asyncio - + # Process a complex query result = await mixture_of_agents_tool( user_prompt="Solve this complex mathematical proof..." ) """ +import asyncio +import datetime import json import logging import os -import asyncio -import datetime -from typing import Dict, Any, List, Optional -from tools.openrouter_client import get_async_client as _get_openrouter_client, check_api_key as check_openrouter_api_key +from typing import Any + from tools.debug_helpers import DebugSession +from tools.openrouter_client import ( + check_api_key as check_openrouter_api_key, +) +from tools.openrouter_client import ( + get_async_client as _get_openrouter_client, +) logger = logging.getLogger(__name__) @@ -60,9 +66,9 @@ logger = logging.getLogger(__name__) # Reference models - these generate diverse initial responses in parallel (OpenRouter slugs) REFERENCE_MODELS = [ "anthropic/claude-opus-4.5", - "google/gemini-3-pro-preview", + "google/gemini-3-pro-preview", "openai/gpt-5.2-pro", - "deepseek/deepseek-v3.2" + "deepseek/deepseek-v3.2", ] # Aggregator model - synthesizes reference responses into final output @@ -83,18 +89,18 @@ Responses from models:""" _debug = DebugSession("moa_tools", env_var="MOA_TOOLS_DEBUG") -def _construct_aggregator_prompt(system_prompt: str, responses: List[str]) -> str: +def _construct_aggregator_prompt(system_prompt: str, responses: list[str]) -> str: """ Construct the final system prompt for the aggregator including all model responses. - + Args: system_prompt (str): Base system prompt for aggregation responses (List[str]): List of responses from reference models - + Returns: str: Complete system prompt with enumerated responses """ - response_text = "\n".join([f"{i+1}. {response}" for i, response in enumerate(responses)]) + response_text = "\n".join([f"{i + 1}. {response}" for i, response in enumerate(responses)]) return f"{system_prompt}\n\n{response_text}" @@ -103,48 +109,43 @@ async def _run_reference_model_safe( user_prompt: str, temperature: float = REFERENCE_TEMPERATURE, max_tokens: int = 32000, - max_retries: int = 6 + max_retries: int = 6, ) -> tuple[str, str, bool]: """ Run a single reference model with retry logic and graceful failure handling. - + Args: model (str): Model identifier to use user_prompt (str): The user's query temperature (float): Sampling temperature for response generation max_tokens (int): Maximum tokens in response max_retries (int): Maximum number of retry attempts - + Returns: tuple[str, str, bool]: (model_name, response_content_or_error, success_flag) """ for attempt in range(max_retries): try: logger.info("Querying %s (attempt %s/%s)", model, attempt + 1, max_retries) - + # Build parameters for the API call api_params = { "model": model, "messages": [{"role": "user", "content": user_prompt}], - "extra_body": { - "reasoning": { - "enabled": True, - "effort": "xhigh" - } - } + "extra_body": {"reasoning": {"enabled": True, "effort": "xhigh"}}, } - + # GPT models (especially gpt-4o-mini) don't support custom temperature values # Only include temperature for non-GPT models - if not model.lower().startswith('gpt-'): + if not model.lower().startswith("gpt-"): api_params["temperature"] = temperature - + response = await _get_openrouter_client().chat.completions.create(**api_params) - + content = response.choices[0].message.content.strip() logger.info("%s responded (%s characters)", model, len(content)) return model, content, True - + except Exception as e: error_str = str(e) # Log more detailed error information for debugging @@ -154,7 +155,7 @@ async def _run_reference_model_safe( logger.warning("%s rate limit error (attempt %s): %s", model, attempt + 1, error_str) else: logger.warning("%s unknown error (attempt %s): %s", model, attempt + 1, error_str) - + if attempt < max_retries - 1: # Exponential backoff for rate limiting: 2s, 4s, 8s, 16s, 32s, 60s sleep_time = min(2 ** (attempt + 1), 60) @@ -167,60 +168,47 @@ async def _run_reference_model_safe( async def _run_aggregator_model( - system_prompt: str, - user_prompt: str, - temperature: float = AGGREGATOR_TEMPERATURE, - max_tokens: int = None + system_prompt: str, user_prompt: str, temperature: float = AGGREGATOR_TEMPERATURE, max_tokens: int = None ) -> str: """ Run the aggregator model to synthesize the final response. - + Args: system_prompt (str): System prompt with all reference responses user_prompt (str): Original user query temperature (float): Focused temperature for consistent aggregation max_tokens (int): Maximum tokens in final response - + Returns: str: Synthesized final response """ logger.info("Running aggregator model: %s", AGGREGATOR_MODEL) - + # Build parameters for the API call api_params = { "model": AGGREGATOR_MODEL, - "messages": [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} - ], - "extra_body": { - "reasoning": { - "enabled": True, - "effort": "xhigh" - } - } + "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}], + "extra_body": {"reasoning": {"enabled": True, "effort": "xhigh"}}, } - + # GPT models (especially gpt-4o-mini) don't support custom temperature values # Only include temperature for non-GPT models - if not AGGREGATOR_MODEL.lower().startswith('gpt-'): + if not AGGREGATOR_MODEL.lower().startswith("gpt-"): api_params["temperature"] = temperature - + response = await _get_openrouter_client().chat.completions.create(**api_params) - + content = response.choices[0].message.content.strip() logger.info("Aggregation complete (%s characters)", len(content)) return content async def mixture_of_agents_tool( - user_prompt: str, - reference_models: Optional[List[str]] = None, - aggregator_model: Optional[str] = None + user_prompt: str, reference_models: list[str] | None = None, aggregator_model: str | None = None ) -> str: """ Process a complex query using the Mixture-of-Agents methodology. - + This tool leverages multiple frontier language models to collaboratively solve extremely difficult problems requiring intense reasoning. It's particularly effective for: @@ -229,16 +217,16 @@ async def mixture_of_agents_tool( - Multi-step analytical reasoning tasks - Problems requiring diverse domain expertise - Tasks where single models show limitations - + The MoA approach uses a fixed 2-layer architecture: 1. Layer 1: Multiple reference models generate diverse responses in parallel (temp=0.6) 2. Layer 2: Aggregator model synthesizes the best elements into final response (temp=0.4) - + Args: user_prompt (str): The complex query or problem to solve reference_models (Optional[List[str]]): Custom reference models to use aggregator_model (Optional[str]): Custom aggregator model to use - + Returns: str: JSON string containing the MoA results with the following structure: { @@ -250,12 +238,12 @@ async def mixture_of_agents_tool( }, "processing_time": float } - + Raises: Exception: If MoA processing fails or API key is not set """ start_time = datetime.datetime.now() - + debug_call_data = { "parameters": { "user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt, @@ -263,7 +251,7 @@ async def mixture_of_agents_tool( "aggregator_model": aggregator_model or AGGREGATOR_MODEL, "reference_temperature": REFERENCE_TEMPERATURE, "aggregator_temperature": AGGREGATOR_TEMPERATURE, - "min_successful_references": MIN_SUCCESSFUL_REFERENCES + "min_successful_references": MIN_SUCCESSFUL_REFERENCES, }, "error": None, "success": False, @@ -272,161 +260,152 @@ async def mixture_of_agents_tool( "failed_models": [], "final_response_length": 0, "processing_time_seconds": 0, - "models_used": {} + "models_used": {}, } - + try: logger.info("Starting Mixture-of-Agents processing...") logger.info("Query: %s", user_prompt[:100]) - + # Validate API key availability if not os.getenv("OPENROUTER_API_KEY"): raise ValueError("OPENROUTER_API_KEY environment variable not set") - + # Use provided models or defaults ref_models = reference_models or REFERENCE_MODELS agg_model = aggregator_model or AGGREGATOR_MODEL - + logger.info("Using %s reference models in 2-layer MoA architecture", len(ref_models)) - + # Layer 1: Generate diverse responses from reference models (with failure handling) logger.info("Layer 1: Generating reference responses...") - model_results = await asyncio.gather(*[ - _run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE) - for model in ref_models - ]) - + model_results = await asyncio.gather( + *[_run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE) for model in ref_models] + ) + # Separate successful and failed responses successful_responses = [] failed_models = [] - + for model_name, content, success in model_results: if success: successful_responses.append(content) else: failed_models.append(model_name) - + successful_count = len(successful_responses) failed_count = len(failed_models) - + logger.info("Reference model results: %s successful, %s failed", successful_count, failed_count) - + if failed_models: - logger.warning("Failed models: %s", ', '.join(failed_models)) - + logger.warning("Failed models: %s", ", ".join(failed_models)) + # Check if we have enough successful responses to proceed if successful_count < MIN_SUCCESSFUL_REFERENCES: - raise ValueError(f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses.") - + raise ValueError( + f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses." + ) + debug_call_data["reference_responses_count"] = successful_count debug_call_data["failed_models_count"] = failed_count debug_call_data["failed_models"] = failed_models - + # Layer 2: Aggregate responses using the aggregator model logger.info("Layer 2: Synthesizing final response...") - aggregator_system_prompt = _construct_aggregator_prompt( - AGGREGATOR_SYSTEM_PROMPT, - successful_responses - ) - - final_response = await _run_aggregator_model( - aggregator_system_prompt, - user_prompt, - AGGREGATOR_TEMPERATURE - ) - + aggregator_system_prompt = _construct_aggregator_prompt(AGGREGATOR_SYSTEM_PROMPT, successful_responses) + + final_response = await _run_aggregator_model(aggregator_system_prompt, user_prompt, AGGREGATOR_TEMPERATURE) + # Calculate processing time end_time = datetime.datetime.now() processing_time = (end_time - start_time).total_seconds() - + logger.info("MoA processing completed in %.2f seconds", processing_time) - + # Prepare successful response (only final aggregated result, minimal fields) result = { "success": True, "response": final_response, - "models_used": { - "reference_models": ref_models, - "aggregator_model": agg_model - } + "models_used": {"reference_models": ref_models, "aggregator_model": agg_model}, } - + debug_call_data["success"] = True debug_call_data["final_response_length"] = len(final_response) debug_call_data["processing_time_seconds"] = processing_time debug_call_data["models_used"] = result["models_used"] - + # Log debug information _debug.log_call("mixture_of_agents_tool", debug_call_data) _debug.save() - + return json.dumps(result, indent=2, ensure_ascii=False) - + except Exception as e: error_msg = f"Error in MoA processing: {str(e)}" logger.error("%s", error_msg) - + # Calculate processing time even for errors end_time = datetime.datetime.now() processing_time = (end_time - start_time).total_seconds() - + # Prepare error response (minimal fields) result = { "success": False, "response": "MoA processing failed. Please try again or use a single model for this query.", "models_used": { "reference_models": reference_models or REFERENCE_MODELS, - "aggregator_model": aggregator_model or AGGREGATOR_MODEL + "aggregator_model": aggregator_model or AGGREGATOR_MODEL, }, - "error": error_msg + "error": error_msg, } - + debug_call_data["error"] = error_msg debug_call_data["processing_time_seconds"] = processing_time _debug.log_call("mixture_of_agents_tool", debug_call_data) _debug.save() - + return json.dumps(result, indent=2, ensure_ascii=False) def check_moa_requirements() -> bool: """ Check if all requirements for MoA tools are met. - + Returns: bool: True if requirements are met, False otherwise """ return check_openrouter_api_key() -def get_debug_session_info() -> Dict[str, Any]: +def get_debug_session_info() -> dict[str, Any]: """ Get information about the current debug session. - + Returns: Dict[str, Any]: Dictionary containing debug session information """ return _debug.get_session_info() -def get_available_models() -> Dict[str, List[str]]: +def get_available_models() -> dict[str, list[str]]: """ Get information about available models for MoA processing. - + Returns: Dict[str, List[str]]: Dictionary with reference and aggregator models """ return { "reference_models": REFERENCE_MODELS, "aggregator_models": [AGGREGATOR_MODEL], - "supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL] + "supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL], } -def get_moa_configuration() -> Dict[str, Any]: +def get_moa_configuration() -> dict[str, Any]: """ Get the current MoA configuration settings. - + Returns: Dict[str, Any]: Dictionary containing all configuration parameters """ @@ -437,7 +416,7 @@ def get_moa_configuration() -> Dict[str, Any]: "aggregator_temperature": AGGREGATOR_TEMPERATURE, "min_successful_references": MIN_SUCCESSFUL_REFERENCES, "total_reference_models": len(REFERENCE_MODELS), - "failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail" + "failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail", } @@ -447,10 +426,10 @@ if __name__ == "__main__": """ print("🤖 Mixture-of-Agents Tool Module") print("=" * 50) - + # Check if API key is available api_available = check_openrouter_api_key() - + if not api_available: print("❌ OPENROUTER_API_KEY environment variable not set") print("Please set your API key: export OPENROUTER_API_KEY='your-key-here'") @@ -458,26 +437,26 @@ if __name__ == "__main__": exit(1) else: print("✅ OpenRouter API key found") - + print("🛠️ MoA tools ready for use!") - + # Show current configuration config = get_moa_configuration() - print(f"\n⚙️ Current Configuration:") + print("\n⚙️ Current Configuration:") print(f" 🤖 Reference models ({len(config['reference_models'])}): {', '.join(config['reference_models'])}") print(f" 🧠 Aggregator model: {config['aggregator_model']}") print(f" 🌡️ Reference temperature: {config['reference_temperature']}") print(f" 🌡️ Aggregator temperature: {config['aggregator_temperature']}") print(f" 🛡️ Failure tolerance: {config['failure_tolerance']}") print(f" 📊 Minimum successful models: {config['min_successful_references']}") - + # Show debug mode status if _debug.active: print(f"\n🐛 Debug mode ENABLED - Session ID: {_debug.session_id}") print(f" Debug logs will be saved to: ./logs/moa_tools_debug_{_debug.session_id}.json") else: print("\n🐛 Debug mode disabled (set MOA_TOOLS_DEBUG=true to enable)") - + print("\nBasic usage:") print(" from mixture_of_agents_tool import mixture_of_agents_tool") print(" import asyncio") @@ -488,24 +467,26 @@ if __name__ == "__main__": print(" )") print(" print(result)") print(" asyncio.run(main())") - + print("\nBest use cases:") print(" - Complex mathematical proofs and calculations") print(" - Advanced coding problems and algorithm design") print(" - Multi-step analytical reasoning tasks") print(" - Problems requiring diverse domain expertise") print(" - Tasks where single models show limitations") - + print("\nPerformance characteristics:") print(" - Higher latency due to multiple model calls") print(" - Significantly improved quality for complex tasks") print(" - Parallel processing for efficiency") - print(f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation") + print( + f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation" + ) print(" - Token-efficient: only returns final aggregated response") print(" - Resilient: continues with partial model failures") - print(f" - Configurable: easy to modify models and settings at top of file") + print(" - Configurable: easy to modify models and settings at top of file") print(" - State-of-the-art results on challenging benchmarks") - + print("\nDebug mode:") print(" # Enable debug logging") print(" export MOA_TOOLS_DEBUG=true") @@ -526,11 +507,11 @@ MOA_SCHEMA = { "properties": { "user_prompt": { "type": "string", - "description": "The complex query or problem to solve using multiple AI models. Should be a challenging problem that benefits from diverse perspectives and collaborative reasoning." + "description": "The complex query or problem to solve using multiple AI models. Should be a challenging problem that benefits from diverse perspectives and collaborative reasoning.", } }, - "required": ["user_prompt"] - } + "required": ["user_prompt"], + }, } registry.register( diff --git a/tools/openrouter_client.py b/tools/openrouter_client.py index 343cf1021d..fa5d27c689 100644 --- a/tools/openrouter_client.py +++ b/tools/openrouter_client.py @@ -1,7 +1,7 @@ """Shared OpenRouter API client for Hermes tools. Provides a single lazy-initialized AsyncOpenAI client that all tool modules -can share, eliminating the duplicated _get_openrouter_client() / +can share, eliminating the duplicated _get_openrouter_client() / _get_summarizer_client() pattern previously copy-pasted across web_tools, vision_tools, mixture_of_agents_tool, and session_search_tool. """ @@ -9,6 +9,7 @@ vision_tools, mixture_of_agents_tool, and session_search_tool. import os from openai import AsyncOpenAI + from hermes_constants import OPENROUTER_BASE_URL _client: AsyncOpenAI | None = None diff --git a/tools/patch_parser.py b/tools/patch_parser.py index 716036f38d..afc55862fa 100644 --- a/tools/patch_parser.py +++ b/tools/patch_parser.py @@ -20,7 +20,7 @@ V4A Format: Usage: from tools.patch_parser import parse_v4a_patch, apply_v4a_operations - + operations, error = parse_v4a_patch(patch_content) if error: print(f"Parse error: {error}") @@ -30,8 +30,8 @@ Usage: import re from dataclasses import dataclass, field -from typing import List, Optional, Tuple, Any from enum import Enum +from typing import Any class OperationType(Enum): @@ -44,6 +44,7 @@ class OperationType(Enum): @dataclass class HunkLine: """A single line in a patch hunk.""" + prefix: str # ' ', '-', or '+' content: str @@ -51,182 +52,174 @@ class HunkLine: @dataclass class Hunk: """A group of changes within a file.""" - context_hint: Optional[str] = None - lines: List[HunkLine] = field(default_factory=list) + + context_hint: str | None = None + lines: list[HunkLine] = field(default_factory=list) @dataclass class PatchOperation: """A single operation in a V4A patch.""" + operation: OperationType file_path: str - new_path: Optional[str] = None # For move operations - hunks: List[Hunk] = field(default_factory=list) - content: Optional[str] = None # For add file operations + new_path: str | None = None # For move operations + hunks: list[Hunk] = field(default_factory=list) + content: str | None = None # For add file operations -def parse_v4a_patch(patch_content: str) -> Tuple[List[PatchOperation], Optional[str]]: +def parse_v4a_patch(patch_content: str) -> tuple[list[PatchOperation], str | None]: """ Parse a V4A format patch. - + Args: patch_content: The patch text in V4A format - + Returns: Tuple of (operations, error_message) - If successful: (list_of_operations, None) - If failed: ([], error_description) """ - lines = patch_content.split('\n') - operations: List[PatchOperation] = [] - + lines = patch_content.split("\n") + operations: list[PatchOperation] = [] + # Find patch boundaries start_idx = None end_idx = None - + for i, line in enumerate(lines): - if '*** Begin Patch' in line or '***Begin Patch' in line: + if "*** Begin Patch" in line or "***Begin Patch" in line: start_idx = i - elif '*** End Patch' in line or '***End Patch' in line: + elif "*** End Patch" in line or "***End Patch" in line: end_idx = i break - + if start_idx is None: # Try to parse without explicit begin marker start_idx = -1 - + if end_idx is None: end_idx = len(lines) - + # Parse operations between boundaries i = start_idx + 1 - current_op: Optional[PatchOperation] = None - current_hunk: Optional[Hunk] = None - + current_op: PatchOperation | None = None + current_hunk: Hunk | None = None + while i < end_idx: line = lines[i] - + # Check for file operation markers - update_match = re.match(r'\*\*\*\s*Update\s+File:\s*(.+)', line) - add_match = re.match(r'\*\*\*\s*Add\s+File:\s*(.+)', line) - delete_match = re.match(r'\*\*\*\s*Delete\s+File:\s*(.+)', line) - move_match = re.match(r'\*\*\*\s*Move\s+File:\s*(.+?)\s*->\s*(.+)', line) - + update_match = re.match(r"\*\*\*\s*Update\s+File:\s*(.+)", line) + add_match = re.match(r"\*\*\*\s*Add\s+File:\s*(.+)", line) + delete_match = re.match(r"\*\*\*\s*Delete\s+File:\s*(.+)", line) + move_match = re.match(r"\*\*\*\s*Move\s+File:\s*(.+?)\s*->\s*(.+)", line) + if update_match: # Save previous operation if current_op: if current_hunk and current_hunk.lines: current_op.hunks.append(current_hunk) operations.append(current_op) - - current_op = PatchOperation( - operation=OperationType.UPDATE, - file_path=update_match.group(1).strip() - ) + + current_op = PatchOperation(operation=OperationType.UPDATE, file_path=update_match.group(1).strip()) current_hunk = None - + elif add_match: if current_op: if current_hunk and current_hunk.lines: current_op.hunks.append(current_hunk) operations.append(current_op) - - current_op = PatchOperation( - operation=OperationType.ADD, - file_path=add_match.group(1).strip() - ) + + current_op = PatchOperation(operation=OperationType.ADD, file_path=add_match.group(1).strip()) current_hunk = Hunk() - + elif delete_match: if current_op: if current_hunk and current_hunk.lines: current_op.hunks.append(current_hunk) operations.append(current_op) - - current_op = PatchOperation( - operation=OperationType.DELETE, - file_path=delete_match.group(1).strip() - ) + + current_op = PatchOperation(operation=OperationType.DELETE, file_path=delete_match.group(1).strip()) operations.append(current_op) current_op = None current_hunk = None - + elif move_match: if current_op: if current_hunk and current_hunk.lines: current_op.hunks.append(current_hunk) operations.append(current_op) - + current_op = PatchOperation( operation=OperationType.MOVE, file_path=move_match.group(1).strip(), - new_path=move_match.group(2).strip() + new_path=move_match.group(2).strip(), ) operations.append(current_op) current_op = None current_hunk = None - - elif line.startswith('@@'): + + elif line.startswith("@@"): # Context hint / hunk marker if current_op: if current_hunk and current_hunk.lines: current_op.hunks.append(current_hunk) - + # Extract context hint - hint_match = re.match(r'@@\s*(.+?)\s*@@', line) + hint_match = re.match(r"@@\s*(.+?)\s*@@", line) hint = hint_match.group(1) if hint_match else None current_hunk = Hunk(context_hint=hint) - + elif current_op and line: # Parse hunk line if current_hunk is None: current_hunk = Hunk() - - if line.startswith('+'): - current_hunk.lines.append(HunkLine('+', line[1:])) - elif line.startswith('-'): - current_hunk.lines.append(HunkLine('-', line[1:])) - elif line.startswith(' '): - current_hunk.lines.append(HunkLine(' ', line[1:])) - elif line.startswith('\\'): + + if line.startswith("+"): + current_hunk.lines.append(HunkLine("+", line[1:])) + elif line.startswith("-"): + current_hunk.lines.append(HunkLine("-", line[1:])) + elif line.startswith(" "): + current_hunk.lines.append(HunkLine(" ", line[1:])) + elif line.startswith("\\"): # "\ No newline at end of file" marker - skip pass else: # Treat as context line (implicit space prefix) - current_hunk.lines.append(HunkLine(' ', line)) - + current_hunk.lines.append(HunkLine(" ", line)) + i += 1 - + # Don't forget the last operation if current_op: if current_hunk and current_hunk.lines: current_op.hunks.append(current_hunk) operations.append(current_op) - + return operations, None -def apply_v4a_operations(operations: List[PatchOperation], - file_ops: Any) -> 'PatchResult': +def apply_v4a_operations(operations: list[PatchOperation], file_ops: Any) -> "PatchResult": """ Apply V4A patch operations using a file operations interface. - + Args: operations: List of PatchOperation from parse_v4a_patch file_ops: Object with read_file, write_file methods - + Returns: PatchResult with results of all operations """ # Import here to avoid circular imports from tools.file_operations import PatchResult - + files_modified = [] files_created = [] files_deleted = [] all_diffs = [] errors = [] - + for op in operations: try: if op.operation == OperationType.ADD: @@ -236,7 +229,7 @@ def apply_v4a_operations(operations: List[PatchOperation], all_diffs.append(result[1]) else: errors.append(f"Failed to add {op.file_path}: {result[1]}") - + elif op.operation == OperationType.DELETE: result = _apply_delete(op, file_ops) if result[0]: @@ -244,7 +237,7 @@ def apply_v4a_operations(operations: List[PatchOperation], all_diffs.append(result[1]) else: errors.append(f"Failed to delete {op.file_path}: {result[1]}") - + elif op.operation == OperationType.MOVE: result = _apply_move(op, file_ops) if result[0]: @@ -252,7 +245,7 @@ def apply_v4a_operations(operations: List[PatchOperation], all_diffs.append(result[1]) else: errors.append(f"Failed to move {op.file_path}: {result[1]}") - + elif op.operation == OperationType.UPDATE: result = _apply_update(op, file_ops) if result[0]: @@ -260,19 +253,19 @@ def apply_v4a_operations(operations: List[PatchOperation], all_diffs.append(result[1]) else: errors.append(f"Failed to update {op.file_path}: {result[1]}") - + except Exception as e: errors.append(f"Error processing {op.file_path}: {str(e)}") - + # Run lint on all modified/created files lint_results = {} for f in files_modified + files_created: - if hasattr(file_ops, '_check_lint'): + if hasattr(file_ops, "_check_lint"): lint_result = file_ops._check_lint(f) lint_results[f] = lint_result.to_dict() - - combined_diff = '\n'.join(all_diffs) - + + combined_diff = "\n".join(all_diffs) + if errors: return PatchResult( success=False, @@ -281,123 +274,124 @@ def apply_v4a_operations(operations: List[PatchOperation], files_created=files_created, files_deleted=files_deleted, lint=lint_results if lint_results else None, - error='; '.join(errors) + error="; ".join(errors), ) - + return PatchResult( success=True, diff=combined_diff, files_modified=files_modified, files_created=files_created, files_deleted=files_deleted, - lint=lint_results if lint_results else None + lint=lint_results if lint_results else None, ) -def _apply_add(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]: +def _apply_add(op: PatchOperation, file_ops: Any) -> tuple[bool, str]: """Apply an add file operation.""" # Extract content from hunks (all + lines) content_lines = [] for hunk in op.hunks: for line in hunk.lines: - if line.prefix == '+': + if line.prefix == "+": content_lines.append(line.content) - - content = '\n'.join(content_lines) - + + content = "\n".join(content_lines) + result = file_ops.write_file(op.file_path, content) if result.error: return False, result.error - + diff = f"--- /dev/null\n+++ b/{op.file_path}\n" - diff += '\n'.join(f"+{line}" for line in content_lines) - + diff += "\n".join(f"+{line}" for line in content_lines) + return True, diff -def _apply_delete(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]: +def _apply_delete(op: PatchOperation, file_ops: Any) -> tuple[bool, str]: """Apply a delete file operation.""" # Read file first for diff read_result = file_ops.read_file(op.file_path) - + if read_result.error and "not found" in read_result.error.lower(): # File doesn't exist, nothing to delete return True, f"# {op.file_path} already deleted or doesn't exist" - + # Delete directly via shell command using the underlying environment rm_result = file_ops._exec(f"rm -f {file_ops._escape_shell_arg(op.file_path)}") - + if rm_result.exit_code != 0: return False, rm_result.stdout - + diff = f"--- a/{op.file_path}\n+++ /dev/null\n# File deleted" return True, diff -def _apply_move(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]: +def _apply_move(op: PatchOperation, file_ops: Any) -> tuple[bool, str]: """Apply a move file operation.""" # Use shell mv command mv_result = file_ops._exec( f"mv {file_ops._escape_shell_arg(op.file_path)} {file_ops._escape_shell_arg(op.new_path)}" ) - + if mv_result.exit_code != 0: return False, mv_result.stdout - + diff = f"# Moved: {op.file_path} -> {op.new_path}" return True, diff -def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]: +def _apply_update(op: PatchOperation, file_ops: Any) -> tuple[bool, str]: """Apply an update file operation.""" # Read current content read_result = file_ops.read_file(op.file_path, limit=10000) - + if read_result.error: return False, f"Cannot read file: {read_result.error}" - + # Parse content (remove line numbers) current_lines = [] - for line in read_result.content.split('\n'): - if '|' in line: + for line in read_result.content.split("\n"): + if "|" in line: # Line format: " 123|content" - parts = line.split('|', 1) + parts = line.split("|", 1) if len(parts) == 2: current_lines.append(parts[1]) else: current_lines.append(line) else: current_lines.append(line) - - current_content = '\n'.join(current_lines) - + + current_content = "\n".join(current_lines) + # Apply each hunk new_content = current_content - + for hunk in op.hunks: # Build search pattern from context and removed lines search_lines = [] replace_lines = [] - + for line in hunk.lines: - if line.prefix == ' ': + if line.prefix == " ": search_lines.append(line.content) replace_lines.append(line.content) - elif line.prefix == '-': + elif line.prefix == "-": search_lines.append(line.content) - elif line.prefix == '+': + elif line.prefix == "+": replace_lines.append(line.content) - + if search_lines: - search_pattern = '\n'.join(search_lines) - replacement = '\n'.join(replace_lines) - + search_pattern = "\n".join(search_lines) + replacement = "\n".join(replace_lines) + # Use fuzzy matching from tools.fuzzy_match import fuzzy_find_and_replace + new_content, count, error = fuzzy_find_and_replace( new_content, search_pattern, replacement, replace_all=False ) - + if error and count == 0: # Try with context hint if available if hunk.context_hint: @@ -408,31 +402,32 @@ def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]: window_start = max(0, hint_pos - 500) window_end = min(len(new_content), hint_pos + 2000) window = new_content[window_start:window_end] - + window_new, count, error = fuzzy_find_and_replace( window, search_pattern, replacement, replace_all=False ) - + if count > 0: new_content = new_content[:window_start] + window_new + new_content[window_end:] error = None - + if error: return False, f"Could not apply hunk: {error}" - + # Write new content write_result = file_ops.write_file(op.file_path, new_content) if write_result.error: return False, write_result.error - + # Generate diff import difflib + diff_lines = difflib.unified_diff( current_content.splitlines(keepends=True), new_content.splitlines(keepends=True), fromfile=f"a/{op.file_path}", - tofile=f"b/{op.file_path}" + tofile=f"b/{op.file_path}", ) - diff = ''.join(diff_lines) - + diff = "".join(diff_lines) + return True, diff diff --git a/tools/process_registry.py b/tools/process_registry.py index 948f2a4f30..9182ca968e 100644 --- a/tools/process_registry.py +++ b/tools/process_registry.py @@ -34,7 +34,6 @@ import logging import os import platform import shlex -import shutil import signal import subprocess import threading @@ -42,10 +41,11 @@ import time import uuid _IS_WINDOWS = platform.system() == "Windows" -from tools.environments.local import _find_shell from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any + +from tools.environments.local import _find_shell logger = logging.getLogger(__name__) @@ -54,30 +54,31 @@ logger = logging.getLogger(__name__) CHECKPOINT_PATH = Path(os.path.expanduser("~/.hermes/processes.json")) # Limits -MAX_OUTPUT_CHARS = 200_000 # 200KB rolling output buffer -FINISHED_TTL_SECONDS = 1800 # Keep finished processes for 30 minutes -MAX_PROCESSES = 64 # Max concurrent tracked processes (LRU pruning) +MAX_OUTPUT_CHARS = 200_000 # 200KB rolling output buffer +FINISHED_TTL_SECONDS = 1800 # Keep finished processes for 30 minutes +MAX_PROCESSES = 64 # Max concurrent tracked processes (LRU pruning) @dataclass class ProcessSession: """A tracked background process with output buffering.""" - id: str # Unique session ID ("proc_xxxxxxxxxxxx") - command: str # Original command string - task_id: str = "" # Task/sandbox isolation key - session_key: str = "" # Gateway session key (for reset protection) - pid: Optional[int] = None # OS process ID - process: Optional[subprocess.Popen] = None # Popen handle (local only) - env_ref: Any = None # Reference to the environment object - cwd: Optional[str] = None # Working directory - started_at: float = 0.0 # time.time() of spawn - exited: bool = False # Whether the process has finished - exit_code: Optional[int] = None # Exit code (None if still running) - output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS) + + id: str # Unique session ID ("proc_xxxxxxxxxxxx") + command: str # Original command string + task_id: str = "" # Task/sandbox isolation key + session_key: str = "" # Gateway session key (for reset protection) + pid: int | None = None # OS process ID + process: subprocess.Popen | None = None # Popen handle (local only) + env_ref: Any = None # Reference to the environment object + cwd: str | None = None # Working directory + started_at: float = 0.0 # time.time() of spawn + exited: bool = False # Whether the process has finished + exit_code: int | None = None # Exit code (None if still running) + output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS) max_output_chars: int = MAX_OUTPUT_CHARS - detached: bool = False # True if recovered from crash (no pipe) + detached: bool = False # True if recovered from crash (no pipe) _lock: threading.Lock = field(default_factory=threading.Lock) - _reader_thread: Optional[threading.Thread] = field(default=None, repr=False) + _reader_thread: threading.Thread | None = field(default=None, repr=False) _pty: Any = field(default=None, repr=False) # ptyprocess handle (when use_pty=True) @@ -100,12 +101,12 @@ class ProcessRegistry: ) def __init__(self): - self._running: Dict[str, ProcessSession] = {} - self._finished: Dict[str, ProcessSession] = {} + self._running: dict[str, ProcessSession] = {} + self._finished: dict[str, ProcessSession] = {} self._lock = threading.Lock() # Side-channel for check_interval watchers (gateway reads after agent run) - self.pending_watchers: List[Dict[str, Any]] = [] + self.pending_watchers: list[dict[str, Any]] = [] @staticmethod def _clean_shell_noise(text: str) -> str: @@ -149,6 +150,7 @@ class ProcessRegistry: # Try PTY mode for interactive CLI tools try: import ptyprocess + user_shell = _find_shell() pty_env = os.environ | (env_vars or {}) pty_env["PYTHONUNBUFFERED"] = "1" @@ -260,10 +262,7 @@ class ProcessRegistry: log_path = f"/tmp/hermes_bg_{session.id}.log" pid_path = f"/tmp/hermes_bg_{session.id}.pid" quoted_command = shlex.quote(command) - bg_command = ( - f"nohup bash -c {quoted_command} > {log_path} 2>&1 & " - f"echo $! > {pid_path} && cat {pid_path}" - ) + bg_command = f"nohup bash -c {quoted_command} > {log_path} 2>&1 & echo $! > {pid_path} && cat {pid_path}" try: result = env.execute(bg_command, timeout=timeout) @@ -313,7 +312,7 @@ class ProcessRegistry: with session._lock: session.output_buffer += chunk if len(session.output_buffer) > session.max_output_chars: - session.output_buffer = session.output_buffer[-session.max_output_chars:] + session.output_buffer = session.output_buffer[-session.max_output_chars :] except Exception as e: logger.debug("Process stdout reader ended: %s", e) @@ -326,9 +325,7 @@ class ProcessRegistry: session.exit_code = session.process.returncode self._move_to_finished(session) - def _env_poller_loop( - self, session: ProcessSession, env: Any, log_path: str, pid_path: str - ): + def _env_poller_loop(self, session: ProcessSession, env: Any, log_path: str, pid_path: str): """Background thread: poll a sandbox log file for non-local backends.""" while not session.exited: time.sleep(2) # Poll every 2 seconds @@ -340,7 +337,7 @@ class ProcessRegistry: with session._lock: session.output_buffer = new_output if len(session.output_buffer) > session.max_output_chars: - session.output_buffer = session.output_buffer[-session.max_output_chars:] + session.output_buffer = session.output_buffer[-session.max_output_chars :] # Check if process is still running check = env.execute( @@ -383,7 +380,7 @@ class ProcessRegistry: with session._lock: session.output_buffer += text if len(session.output_buffer) > session.max_output_chars: - session.output_buffer = session.output_buffer[-session.max_output_chars:] + session.output_buffer = session.output_buffer[-session.max_output_chars :] except EOFError: break except Exception: @@ -397,7 +394,7 @@ class ProcessRegistry: except Exception as e: logger.debug("PTY wait timed out or failed: %s", e) session.exited = True - session.exit_code = pty.exitstatus if hasattr(pty, 'exitstatus') else -1 + session.exit_code = pty.exitstatus if hasattr(pty, "exitstatus") else -1 self._move_to_finished(session) def _move_to_finished(self, session: ProcessSession): @@ -409,7 +406,7 @@ class ProcessRegistry: # ----- Query Methods ----- - def get(self, session_id: str) -> Optional[ProcessSession]: + def get(self, session_id: str) -> ProcessSession | None: """Get a session by ID (running or finished).""" with self._lock: return self._running.get(session_id) or self._finished.get(session_id) @@ -454,7 +451,7 @@ class ProcessRegistry: if offset == 0 and limit > 0: selected = lines[-limit:] else: - selected = lines[offset:offset + limit] + selected = lines[offset : offset + limit] return { "session_id": session.id, @@ -485,10 +482,7 @@ class ProcessRegistry: if requested_timeout and requested_timeout > max_timeout: effective_timeout = max_timeout - timeout_note = ( - f"Requested wait of {requested_timeout}s was clamped " - f"to configured limit of {max_timeout}s" - ) + timeout_note = f"Requested wait of {requested_timeout}s was clamped to configured limit of {max_timeout}s" else: effective_timeout = requested_timeout or max_timeout @@ -581,7 +575,7 @@ class ProcessRegistry: return {"status": "already_exited", "error": "Process has already finished"} # PTY mode -- write through pty handle (expects bytes) - if hasattr(session, '_pty') and session._pty: + if hasattr(session, "_pty") and session._pty: try: pty_data = data.encode("utf-8") if isinstance(data, str) else data session._pty.write(pty_data) @@ -635,26 +629,17 @@ class ProcessRegistry: def has_active_processes(self, task_id: str) -> bool: """Check if there are active (running) processes for a task_id.""" with self._lock: - return any( - s.task_id == task_id and not s.exited - for s in self._running.values() - ) + return any(s.task_id == task_id and not s.exited for s in self._running.values()) def has_active_for_session(self, session_key: str) -> bool: """Check if there are active processes for a gateway session key.""" with self._lock: - return any( - s.session_key == session_key and not s.exited - for s in self._running.values() - ) + return any(s.session_key == session_key and not s.exited for s in self._running.values()) def kill_all(self, task_id: str = None) -> int: """Kill all running processes, optionally filtered by task_id. Returns count killed.""" with self._lock: - targets = [ - s for s in self._running.values() - if (task_id is None or s.task_id == task_id) and not s.exited - ] + targets = [s for s in self._running.values() if (task_id is None or s.task_id == task_id) and not s.exited] killed = 0 for session in targets: @@ -669,10 +654,7 @@ class ProcessRegistry: """Remove oldest finished sessions if over MAX_PROCESSES. Must hold _lock.""" # First prune expired finished sessions now = time.time() - expired = [ - sid for sid, s in self._finished.items() - if (now - s.started_at) > FINISHED_TTL_SECONDS - ] + expired = [sid for sid, s in self._finished.items() if (now - s.started_at) > FINISHED_TTL_SECONDS] for sid in expired: del self._finished[sid] @@ -696,18 +678,21 @@ class ProcessRegistry: entries = [] for s in self._running.values(): if not s.exited: - entries.append({ - "session_id": s.id, - "command": s.command, - "pid": s.pid, - "cwd": s.cwd, - "started_at": s.started_at, - "task_id": s.task_id, - "session_key": s.session_key, - }) - + entries.append( + { + "session_id": s.id, + "command": s.command, + "pid": s.pid, + "cwd": s.cwd, + "started_at": s.started_at, + "task_id": s.task_id, + "session_key": s.session_key, + } + ) + # Atomic write to avoid corruption on crash from utils import atomic_json_write + atomic_json_write(CHECKPOINT_PATH, entries) except Exception as e: logger.debug("Failed to write checkpoint file: %s", e, exc_info=True) @@ -759,6 +744,7 @@ class ProcessRegistry: # Clear the checkpoint (will be rewritten as processes finish) try: from utils import atomic_json_write + atomic_json_write(CHECKPOINT_PATH, []) except Exception as e: logger.debug("Could not clear checkpoint file: %s", e, exc_info=True) @@ -790,38 +776,32 @@ PROCESS_SCHEMA = { "action": { "type": "string", "enum": ["list", "poll", "log", "wait", "kill", "write", "submit"], - "description": "Action to perform on background processes" + "description": "Action to perform on background processes", }, "session_id": { "type": "string", - "description": "Process session ID (from terminal background output). Required for all actions except 'list'." + "description": "Process session ID (from terminal background output). Required for all actions except 'list'.", }, "data": { "type": "string", - "description": "Text to send to process stdin (for 'write' and 'submit' actions)" + "description": "Text to send to process stdin (for 'write' and 'submit' actions)", }, "timeout": { "type": "integer", "description": "Max seconds to block for 'wait' action. Returns partial output on timeout.", - "minimum": 1 + "minimum": 1, }, - "offset": { - "type": "integer", - "description": "Line offset for 'log' action (default: last 200 lines)" - }, - "limit": { - "type": "integer", - "description": "Max lines to return for 'log' action", - "minimum": 1 - } + "offset": {"type": "integer", "description": "Line offset for 'log' action (default: last 200 lines)"}, + "limit": {"type": "integer", "description": "Max lines to return for 'log' action", "minimum": 1}, }, - "required": ["action"] - } + "required": ["action"], + }, } def _handle_process(args, **kw): import json as _json + task_id = kw.get("task_id") action = args.get("action", "") # Coerce to string — some models send session_id as an integer @@ -835,8 +815,10 @@ def _handle_process(args, **kw): if action == "poll": return _json.dumps(process_registry.poll(session_id), ensure_ascii=False) elif action == "log": - return _json.dumps(process_registry.read_log( - session_id, offset=args.get("offset", 0), limit=args.get("limit", 200)), ensure_ascii=False) + return _json.dumps( + process_registry.read_log(session_id, offset=args.get("offset", 0), limit=args.get("limit", 200)), + ensure_ascii=False, + ) elif action == "wait": return _json.dumps(process_registry.wait(session_id, timeout=args.get("timeout")), ensure_ascii=False) elif action == "kill": @@ -845,7 +827,10 @@ def _handle_process(args, **kw): return _json.dumps(process_registry.write_stdin(session_id, str(args.get("data", ""))), ensure_ascii=False) elif action == "submit": return _json.dumps(process_registry.submit_stdin(session_id, str(args.get("data", ""))), ensure_ascii=False) - return _json.dumps({"error": f"Unknown process action: {action}. Use: list, poll, log, wait, kill, write, submit"}, ensure_ascii=False) + return _json.dumps( + {"error": f"Unknown process action: {action}. Use: list, poll, log, wait, kill, write, submit"}, + ensure_ascii=False, + ) registry.register( diff --git a/tools/registry.py b/tools/registry.py index b56cb64c3d..bd22a72643 100644 --- a/tools/registry.py +++ b/tools/registry.py @@ -16,7 +16,7 @@ Import chain (circular-import safe): import json import logging -from typing import Any, Callable, Dict, List, Optional, Set +from collections.abc import Callable logger = logging.getLogger(__name__) @@ -25,12 +25,17 @@ class ToolEntry: """Metadata for a single registered tool.""" __slots__ = ( - "name", "toolset", "schema", "handler", "check_fn", - "requires_env", "is_async", "description", + "name", + "toolset", + "schema", + "handler", + "check_fn", + "requires_env", + "is_async", + "description", ) - def __init__(self, name, toolset, schema, handler, check_fn, - requires_env, is_async, description): + def __init__(self, name, toolset, schema, handler, check_fn, requires_env, is_async, description): self.name = name self.toolset = toolset self.schema = schema @@ -45,8 +50,8 @@ class ToolRegistry: """Singleton registry that collects tool schemas + handlers from tool files.""" def __init__(self): - self._tools: Dict[str, ToolEntry] = {} - self._toolset_checks: Dict[str, Callable] = {} + self._tools: dict[str, ToolEntry] = {} + self._toolset_checks: dict[str, Callable] = {} # ------------------------------------------------------------------ # Registration @@ -81,7 +86,7 @@ class ToolRegistry: # Schema retrieval # ------------------------------------------------------------------ - def get_definitions(self, tool_names: Set[str], quiet: bool = False) -> List[dict]: + def get_definitions(self, tool_names: set[str], quiet: bool = False) -> list[dict]: """Return OpenAI-format tool schemas for the requested tool names. Only tools whose ``check_fn()`` returns True (or have no check_fn) @@ -122,6 +127,7 @@ class ToolRegistry: try: if entry.is_async: from model_tools import _run_async + return _run_async(entry.handler(args, **kwargs)) return entry.handler(args, **kwargs) except Exception as e: @@ -132,16 +138,16 @@ class ToolRegistry: # Query helpers (replace redundant dicts in model_tools.py) # ------------------------------------------------------------------ - def get_all_tool_names(self) -> List[str]: + def get_all_tool_names(self) -> list[str]: """Return sorted list of all registered tool names.""" return sorted(self._tools.keys()) - def get_toolset_for_tool(self, name: str) -> Optional[str]: + def get_toolset_for_tool(self, name: str) -> str | None: """Return the toolset a tool belongs to, or None.""" entry = self._tools.get(name) return entry.toolset if entry else None - def get_tool_to_toolset_map(self) -> Dict[str, str]: + def get_tool_to_toolset_map(self) -> dict[str, str]: """Return ``{tool_name: toolset_name}`` for every registered tool.""" return {name: e.toolset for name, e in self._tools.items()} @@ -160,14 +166,14 @@ class ToolRegistry: logger.debug("Toolset %s check raised; marking unavailable", toolset) return False - def check_toolset_requirements(self) -> Dict[str, bool]: + def check_toolset_requirements(self) -> dict[str, bool]: """Return ``{toolset: available_bool}`` for every toolset.""" toolsets = set(e.toolset for e in self._tools.values()) return {ts: self.is_toolset_available(ts) for ts in sorted(toolsets)} - def get_available_toolsets(self) -> Dict[str, dict]: + def get_available_toolsets(self) -> dict[str, dict]: """Return toolset metadata for UI display.""" - toolsets: Dict[str, dict] = {} + toolsets: dict[str, dict] = {} for entry in self._tools.values(): ts = entry.toolset if ts not in toolsets: @@ -184,9 +190,9 @@ class ToolRegistry: toolsets[ts]["requirements"].append(env) return toolsets - def get_toolset_requirements(self) -> Dict[str, dict]: + def get_toolset_requirements(self) -> dict[str, dict]: """Build a TOOLSET_REQUIREMENTS-compatible dict for backward compat.""" - result: Dict[str, dict] = {} + result: dict[str, dict] = {} for entry in self._tools.values(): ts = entry.toolset if ts not in result: @@ -217,11 +223,13 @@ class ToolRegistry: if self.is_toolset_available(ts): available.append(ts) else: - unavailable.append({ - "name": ts, - "env_vars": entry.requires_env, - "tools": [e.name for e in self._tools.values() if e.toolset == ts], - }) + unavailable.append( + { + "name": ts, + "env_vars": entry.requires_env, + "tools": [e.name for e in self._tools.values() if e.toolset == ts], + } + ) return available, unavailable diff --git a/tools/rl_training_tool.py b/tools/rl_training_tool.py index 6ffa6e2379..7eac6d1bef 100644 --- a/tools/rl_training_tool.py +++ b/tools/rl_training_tool.py @@ -37,11 +37,12 @@ import subprocess import sys import time import uuid +from dataclasses import dataclass from datetime import datetime -import yaml -from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any + +import yaml # ============================================================================ # Path Configuration @@ -106,9 +107,11 @@ LOCKED_FIELD_NAMES = set(LOCKED_FIELDS.get("env", {}).keys()) # State Management # ============================================================================ + @dataclass class EnvironmentInfo: """Information about a discovered environment.""" + name: str class_name: str file_path: str @@ -119,27 +122,28 @@ class EnvironmentInfo: @dataclass class RunState: """State for a training run.""" + run_id: str environment: str - config: Dict[str, Any] + config: dict[str, Any] status: str = "pending" # pending, starting, running, stopping, stopped, completed, failed error_message: str = "" wandb_project: str = "" wandb_run_name: str = "" start_time: float = 0.0 # Process handles - api_process: Optional[subprocess.Popen] = None - trainer_process: Optional[subprocess.Popen] = None - env_process: Optional[subprocess.Popen] = None + api_process: subprocess.Popen | None = None + trainer_process: subprocess.Popen | None = None + env_process: subprocess.Popen | None = None # Global state -_environments: List[EnvironmentInfo] = [] -_current_env: Optional[str] = None -_current_config: Dict[str, Any] = {} -_env_config_cache: Dict[str, Dict[str, Dict[str, Any]]] = {} -_active_runs: Dict[str, RunState] = {} -_last_status_check: Dict[str, float] = {} +_environments: list[EnvironmentInfo] = [] +_current_env: str | None = None +_current_config: dict[str, Any] = {} +_env_config_cache: dict[str, dict[str, dict[str, Any]]] = {} +_active_runs: dict[str, RunState] = {} +_last_status_check: dict[str, float] = {} # Rate limiting for status checks (30 minutes) MIN_STATUS_CHECK_INTERVAL = 30 * 60 @@ -149,23 +153,24 @@ MIN_STATUS_CHECK_INTERVAL = 30 * 60 # Environment Discovery # ============================================================================ -def _scan_environments() -> List[EnvironmentInfo]: + +def _scan_environments() -> list[EnvironmentInfo]: """ Scan the environments directory for BaseEnv subclasses using AST. """ environments = [] - + if not ENVIRONMENTS_DIR.exists(): return environments - + for py_file in ENVIRONMENTS_DIR.glob("*.py"): if py_file.name.startswith("_"): continue - + try: - with open(py_file, "r") as f: + with open(py_file) as f: tree = ast.parse(f.read()) - + for node in ast.walk(tree): if isinstance(node, ast.ClassDef): # Check if class has BaseEnv as base @@ -175,13 +180,13 @@ def _scan_environments() -> List[EnvironmentInfo]: base_name = base.id elif isinstance(base, ast.Attribute): base_name = base.attr - + if base_name == "BaseEnv": # Extract name from class attribute if present env_name = py_file.stem description = "" config_class = "BaseEnvConfig" - + for item in node.body: if isinstance(item, ast.Assign): for target in item.targets: @@ -190,30 +195,32 @@ def _scan_environments() -> List[EnvironmentInfo]: env_name = item.value.value elif target.id == "env_config_cls" and isinstance(item.value, ast.Name): config_class = item.value.id - + # Get docstring if isinstance(item, ast.Expr) and isinstance(item.value, ast.Constant): if isinstance(item.value.value, str) and not description: description = item.value.value.split("\n")[0].strip() - - environments.append(EnvironmentInfo( - name=env_name, - class_name=node.name, - file_path=str(py_file), - description=description or f"Environment from {py_file.name}", - config_class=config_class, - )) + + environments.append( + EnvironmentInfo( + name=env_name, + class_name=node.name, + file_path=str(py_file), + description=description or f"Environment from {py_file.name}", + config_class=config_class, + ) + ) break except Exception as e: print(f"Warning: Could not parse {py_file}: {e}") - + return environments -def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]: +def _get_env_config_fields(env_file_path: str) -> dict[str, dict[str, Any]]: """ Dynamically import an environment and extract its config fields. - + Uses config_init() to get the actual config class, with fallback to directly importing BaseEnvConfig if config_init fails. """ @@ -223,18 +230,18 @@ def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]: module = importlib.util.module_from_spec(spec) sys.modules["env_module"] = module spec.loader.exec_module(module) - + # Find the BaseEnv subclass env_class = None for name, obj in vars(module).items(): if isinstance(obj, type) and name != "BaseEnv": - if hasattr(obj, "config_init") and callable(getattr(obj, "config_init")): + if hasattr(obj, "config_init") and callable(obj.config_init): env_class = obj break - + if not env_class: return {} - + # Try calling config_init to get the actual config class config_class = None try: @@ -245,40 +252,41 @@ def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]: print(f"Note: config_init failed ({config_error}), using BaseEnvConfig defaults") try: from atroposlib.envs.base import BaseEnvConfig + config_class = BaseEnvConfig except ImportError: return {} - + if not config_class: return {} - + # Helper to make values JSON-serializable (handle enums, etc.) def make_serializable(val): if val is None: return None - if hasattr(val, 'value'): # Enum + if hasattr(val, "value"): # Enum return val.value - if hasattr(val, 'name') and hasattr(val, '__class__') and 'Enum' in str(type(val)): + if hasattr(val, "name") and hasattr(val, "__class__") and "Enum" in str(type(val)): return val.name return val - + # Extract fields from the Pydantic model fields = {} for field_name, field_info in config_class.model_fields.items(): field_type = field_info.annotation default = make_serializable(field_info.default) description = field_info.description or "" - + is_locked = field_name in LOCKED_FIELD_NAMES - + # Convert type to string type_name = getattr(field_type, "__name__", str(field_type)) if hasattr(field_type, "__origin__"): type_name = str(field_type) - + locked_value = LOCKED_FIELDS.get("env", {}).get(field_name, default) current_value = make_serializable(locked_value) if is_locked else default - + fields[field_name] = { "type": type_name, "default": default, @@ -286,9 +294,9 @@ def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]: "locked": is_locked, "current_value": current_value, } - + return fields - + except Exception as e: print(f"Warning: Could not introspect environment config: {e}") return {} @@ -305,6 +313,7 @@ def _initialize_environments(): # Subprocess Management # ============================================================================ + async def _spawn_training_run(run_state: RunState, config_path: Path): """ Spawn the three processes needed for training: @@ -313,16 +322,16 @@ async def _spawn_training_run(run_state: RunState, config_path: Path): 3. environment.py serve (the Atropos environment) """ run_id = run_state.run_id - + # Log file paths api_log = LOGS_DIR / f"api_{run_id}.log" trainer_log = LOGS_DIR / f"trainer_{run_id}.log" env_log = LOGS_DIR / f"env_{run_id}.log" - + try: # Step 1: Start the Atropos API server (run-api) print(f"[{run_id}] Starting Atropos API server (run-api)...") - + api_log_file = open(api_log, "w") run_state.api_process = subprocess.Popen( ["run-api"], @@ -330,20 +339,20 @@ async def _spawn_training_run(run_state: RunState, config_path: Path): stderr=subprocess.STDOUT, cwd=str(TINKER_ATROPOS_ROOT), ) - + # Wait for API to start await asyncio.sleep(5) - + if run_state.api_process.poll() is not None: run_state.status = "failed" run_state.error_message = f"API server exited with code {run_state.api_process.returncode}. Check {api_log}" return - + print(f"[{run_id}] Atropos API server started") - + # Step 2: Start the Tinker trainer print(f"[{run_id}] Starting Tinker trainer: launch_training.py --config {config_path}") - + trainer_log_file = open(trainer_log, "w") run_state.trainer_process = subprocess.Popen( [sys.executable, "launch_training.py", "--config", str(config_path)], @@ -352,38 +361,40 @@ async def _spawn_training_run(run_state: RunState, config_path: Path): cwd=str(TINKER_ATROPOS_ROOT), env={**os.environ, "TINKER_API_KEY": os.getenv("TINKER_API_KEY", "")}, ) - + # Wait for trainer to initialize (it starts FastAPI inference server on 8001) print(f"[{run_id}] Waiting 30 seconds for trainer to initialize...") await asyncio.sleep(30) - + if run_state.trainer_process.poll() is not None: run_state.status = "failed" - run_state.error_message = f"Trainer exited with code {run_state.trainer_process.returncode}. Check {trainer_log}" + run_state.error_message = ( + f"Trainer exited with code {run_state.trainer_process.returncode}. Check {trainer_log}" + ) if run_state.api_process: run_state.api_process.terminate() return - + print(f"[{run_id}] Trainer started, inference server on port 8001") - + # Step 3: Start the environment print(f"[{run_id}] Waiting 90 more seconds before starting environment...") await asyncio.sleep(90) - + # Find the environment file env_info = None for env in _environments: if env.name == run_state.environment: env_info = env break - + if not env_info: run_state.status = "failed" run_state.error_message = f"Environment '{run_state.environment}' not found" return - + print(f"[{run_id}] Starting environment: {env_info.file_path} serve") - + env_log_file = open(env_log, "w") run_state.env_process = subprocess.Popen( [sys.executable, str(env_info.file_path), "serve", "--config", str(config_path)], @@ -391,26 +402,28 @@ async def _spawn_training_run(run_state: RunState, config_path: Path): stderr=subprocess.STDOUT, cwd=str(TINKER_ATROPOS_ROOT), ) - + # Wait for environment to connect await asyncio.sleep(10) - + if run_state.env_process.poll() is not None: run_state.status = "failed" - run_state.error_message = f"Environment exited with code {run_state.env_process.returncode}. Check {env_log}" + run_state.error_message = ( + f"Environment exited with code {run_state.env_process.returncode}. Check {env_log}" + ) if run_state.trainer_process: run_state.trainer_process.terminate() if run_state.api_process: run_state.api_process.terminate() return - + run_state.status = "running" run_state.start_time = time.time() print(f"[{run_id}] Training run started successfully!") - + # Start background monitoring asyncio.create_task(_monitor_training_run(run_state)) - + except Exception as e: run_state.status = "failed" run_state.error_message = str(e) @@ -421,7 +434,7 @@ async def _monitor_training_run(run_state: RunState): """Background task to monitor a training run.""" while run_state.status == "running": await asyncio.sleep(30) # Check every 30 seconds - + # Check if any process has died if run_state.env_process and run_state.env_process.poll() is not None: exit_code = run_state.env_process.returncode @@ -432,7 +445,7 @@ async def _monitor_training_run(run_state: RunState): run_state.error_message = f"Environment process exited with code {exit_code}" _stop_training_run(run_state) break - + if run_state.trainer_process and run_state.trainer_process.poll() is not None: exit_code = run_state.trainer_process.returncode if exit_code == 0: @@ -442,10 +455,10 @@ async def _monitor_training_run(run_state: RunState): run_state.error_message = f"Trainer process exited with code {exit_code}" _stop_training_run(run_state) break - + if run_state.api_process and run_state.api_process.poll() is not None: run_state.status = "failed" - run_state.error_message = f"API server exited unexpectedly" + run_state.error_message = "API server exited unexpectedly" _stop_training_run(run_state) break @@ -460,7 +473,7 @@ def _stop_training_run(run_state: RunState): run_state.env_process.wait(timeout=10) except subprocess.TimeoutExpired: run_state.env_process.kill() - + if run_state.trainer_process and run_state.trainer_process.poll() is None: print(f"[{run_state.run_id}] Stopping trainer process...") run_state.trainer_process.terminate() @@ -468,7 +481,7 @@ def _stop_training_run(run_state: RunState): run_state.trainer_process.wait(timeout=10) except subprocess.TimeoutExpired: run_state.trainer_process.kill() - + if run_state.api_process and run_state.api_process.poll() is None: print(f"[{run_state.run_id}] Stopping API server...") run_state.api_process.terminate() @@ -476,7 +489,7 @@ def _stop_training_run(run_state: RunState): run_state.api_process.wait(timeout=10) except subprocess.TimeoutExpired: run_state.api_process.kill() - + if run_state.status == "running": run_state.status = "stopped" @@ -485,30 +498,31 @@ def _stop_training_run(run_state: RunState): # Environment Discovery Tools # ============================================================================ + async def rl_list_environments() -> str: """ List all available RL environments. - + Scans tinker-atropos/tinker_atropos/environments/ for Python files containing classes that inherit from BaseEnv. - + Returns information about each environment including: - name: Environment identifier - class_name: Python class name - file_path: Path to the environment file - description: Brief description if available - + TIP: To create or modify RL environments: 1. Use terminal/file tools to inspect existing environments 2. Study how they load datasets, define verifiers, and structure rewards 3. Inspect HuggingFace datasets to understand data formats 4. Copy an existing environment as a template - + Returns: JSON string with list of environments """ _initialize_environments() - + response = { "environments": [ { @@ -524,95 +538,105 @@ async def rl_list_environments() -> str: "Use rl_select_environment(name) to select an environment", "Read the file_path with file tools to understand how each environment works", "Look for load_dataset(), score_answer(), get_next_item() methods", - ] + ], } - + return json.dumps(response, indent=2) async def rl_select_environment(name: str) -> str: """ Select an RL environment for training. - + This loads the environment's configuration fields into memory. After selecting, use rl_get_current_config() to see all configurable options and rl_edit_config() to modify specific fields. - + Args: name: Name of the environment to select (from rl_list_environments) - + Returns: JSON string with selection result, file path, and configurable field count - + TIP: Read the returned file_path to understand how the environment works. """ global _current_env, _current_config, _env_config_cache - + _initialize_environments() - + env_info = None for env in _environments: if env.name == name: env_info = env break - + if not env_info: - return json.dumps({ - "error": f"Environment '{name}' not found", - "available": [e.name for e in _environments], - }, indent=2) - + return json.dumps( + { + "error": f"Environment '{name}' not found", + "available": [e.name for e in _environments], + }, + indent=2, + ) + _current_env = name - + # Dynamically discover config fields config_fields = _get_env_config_fields(env_info.file_path) _env_config_cache[name] = config_fields - + # Initialize current config with defaults for non-locked fields _current_config = {} for field_name, field_info in config_fields.items(): if not field_info.get("locked", False): _current_config[field_name] = field_info.get("default") - + # Auto-set wandb_name to "{env_name}-DATETIME" to avoid overlaps timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") _current_config["wandb_name"] = f"{name}-{timestamp}" - - return json.dumps({ - "message": f"Selected environment: {name}", - "environment": name, - "file_path": env_info.file_path, - }, indent=2) + + return json.dumps( + { + "message": f"Selected environment: {name}", + "environment": name, + "file_path": env_info.file_path, + }, + indent=2, + ) # ============================================================================ # Configuration Tools # ============================================================================ + async def rl_get_current_config() -> str: """ Get the current environment configuration. - + Returns all configurable fields for the selected environment. Each environment may have different configuration options. - + Fields are divided into: - configurable_fields: Can be changed with rl_edit_config() - locked_fields: Infrastructure settings that cannot be changed - + Returns: JSON string with configurable and locked fields """ if not _current_env: - return json.dumps({ - "error": "No environment selected. Use rl_select_environment(name) first.", - }, indent=2) - + return json.dumps( + { + "error": "No environment selected. Use rl_select_environment(name) first.", + }, + indent=2, + ) + config_fields = _env_config_cache.get(_current_env, {}) - + configurable = [] locked = [] - + for field_name, field_info in config_fields.items(): field_data = { "name": field_name, @@ -621,148 +645,174 @@ async def rl_get_current_config() -> str: "description": field_info.get("description", ""), "current_value": _current_config.get(field_name, field_info.get("default")), } - + if field_info.get("locked", False): field_data["locked_value"] = LOCKED_FIELDS.get("env", {}).get(field_name) locked.append(field_data) else: configurable.append(field_data) - - return json.dumps({ - "environment": _current_env, - "configurable_fields": configurable, - "locked_fields": locked, - "tip": "Use rl_edit_config(field, value) to change any configurable field.", - }, indent=2) + + return json.dumps( + { + "environment": _current_env, + "configurable_fields": configurable, + "locked_fields": locked, + "tip": "Use rl_edit_config(field, value) to change any configurable field.", + }, + indent=2, + ) async def rl_edit_config(field: str, value: Any) -> str: """ Update a configuration field. - + Use rl_get_current_config() first to see available fields for the selected environment. Each environment has different options. - + Locked fields (infrastructure settings) cannot be changed. - + Args: field: Name of the field to update (from rl_get_current_config) value: New value for the field - + Returns: JSON string with updated config or error message """ global _current_config - + if not _current_env: - return json.dumps({ - "error": "No environment selected. Use rl_select_environment(name) first.", - }, indent=2) - + return json.dumps( + { + "error": "No environment selected. Use rl_select_environment(name) first.", + }, + indent=2, + ) + config_fields = _env_config_cache.get(_current_env, {}) - + if field not in config_fields: - return json.dumps({ - "error": f"Unknown field '{field}'", - "available_fields": list(config_fields.keys()), - }, indent=2) - + return json.dumps( + { + "error": f"Unknown field '{field}'", + "available_fields": list(config_fields.keys()), + }, + indent=2, + ) + field_info = config_fields[field] if field_info.get("locked", False): - return json.dumps({ - "error": f"Field '{field}' is locked and cannot be changed", - "locked_value": LOCKED_FIELDS.get("env", {}).get(field), - }, indent=2) - + return json.dumps( + { + "error": f"Field '{field}' is locked and cannot be changed", + "locked_value": LOCKED_FIELDS.get("env", {}).get(field), + }, + indent=2, + ) + _current_config[field] = value - - return json.dumps({ - "message": f"Updated {field} = {value}", - "field": field, - "value": value, - "config": _current_config, - }, indent=2) + + return json.dumps( + { + "message": f"Updated {field} = {value}", + "field": field, + "value": value, + "config": _current_config, + }, + indent=2, + ) # ============================================================================ # Training Management Tools # ============================================================================ + async def rl_start_training() -> str: """ Start a new RL training run with the current environment and config. - + Requires an environment to be selected first using rl_select_environment(). Use rl_edit_config() to adjust configuration before starting. - + This spawns three processes: 1. run-api (Atropos trajectory API) 2. launch_training.py (Tinker trainer + inference server) 3. environment.py serve (the selected environment) - + WARNING: Training runs take hours. Use rl_check_status() to monitor progress (recommended: check every 30 minutes at most). - + Returns: JSON string with run_id and initial status """ global _active_runs - + if not _current_env: - return json.dumps({ - "error": "No environment selected. Use rl_select_environment(name) first.", - }, indent=2) - + return json.dumps( + { + "error": "No environment selected. Use rl_select_environment(name) first.", + }, + indent=2, + ) + # Check API keys if not os.getenv("TINKER_API_KEY"): - return json.dumps({ - "error": "TINKER_API_KEY not set. Add it to ~/.hermes/.env", - }, indent=2) - + return json.dumps( + { + "error": "TINKER_API_KEY not set. Add it to ~/.hermes/.env", + }, + indent=2, + ) + # Find environment file env_info = None for env in _environments: if env.name == _current_env: env_info = env break - + if not env_info or not Path(env_info.file_path).exists(): - return json.dumps({ - "error": f"Environment file not found for '{_current_env}'", - }, indent=2) - + return json.dumps( + { + "error": f"Environment file not found for '{_current_env}'", + }, + indent=2, + ) + # Generate run ID run_id = str(uuid.uuid4())[:8] - + # Create config YAML CONFIGS_DIR.mkdir(exist_ok=True) config_path = CONFIGS_DIR / f"run_{run_id}.yaml" - + # Start with locked config as base import copy + run_config = copy.deepcopy(LOCKED_FIELDS) - + if "env" not in run_config: run_config["env"] = {} - + # Apply configurable fields for field_name, value in _current_config.items(): if value is not None and value != "": run_config["env"][field_name] = value - + # Set WandB settings wandb_project = _current_config.get("wandb_project", "atropos-tinker") if "tinker" not in run_config: run_config["tinker"] = {} run_config["tinker"]["wandb_project"] = wandb_project run_config["tinker"]["wandb_run_name"] = f"{_current_env}-{run_id}" - + if "wandb_name" in _current_config and _current_config["wandb_name"]: run_config["env"]["wandb_name"] = _current_config["wandb_name"] - + with open(config_path, "w") as f: yaml.dump(run_config, f, default_flow_style=False) - + # Create run state run_state = RunState( run_id=run_id, @@ -772,85 +822,91 @@ async def rl_start_training() -> str: wandb_project=wandb_project, wandb_run_name=f"{_current_env}-{run_id}", ) - + _active_runs[run_id] = run_state - + # Start training in background asyncio.create_task(_spawn_training_run(run_state, config_path)) - - return json.dumps({ - "run_id": run_id, - "status": "starting", - "environment": _current_env, - "config": _current_config, - "wandb_project": wandb_project, - "wandb_run_name": f"{_current_env}-{run_id}", - "config_path": str(config_path), - "logs": { - "api": str(LOGS_DIR / f"api_{run_id}.log"), - "trainer": str(LOGS_DIR / f"trainer_{run_id}.log"), - "env": str(LOGS_DIR / f"env_{run_id}.log"), + + return json.dumps( + { + "run_id": run_id, + "status": "starting", + "environment": _current_env, + "config": _current_config, + "wandb_project": wandb_project, + "wandb_run_name": f"{_current_env}-{run_id}", + "config_path": str(config_path), + "logs": { + "api": str(LOGS_DIR / f"api_{run_id}.log"), + "trainer": str(LOGS_DIR / f"trainer_{run_id}.log"), + "env": str(LOGS_DIR / f"env_{run_id}.log"), + }, + "message": "Training starting. Use rl_check_status(run_id) to monitor (recommended: every 30 minutes).", }, - "message": "Training starting. Use rl_check_status(run_id) to monitor (recommended: every 30 minutes).", - }, indent=2) + indent=2, + ) async def rl_check_status(run_id: str) -> str: """ Get status and metrics for a training run. - + RATE LIMITED: For long-running training, this function enforces a minimum 30-minute interval between checks for the same run_id. - + Args: run_id: The run ID returned by rl_start_training() - + Returns: JSON string with run status and metrics """ global _last_status_check - + # Check rate limiting now = time.time() if run_id in _last_status_check: elapsed = now - _last_status_check[run_id] if elapsed < MIN_STATUS_CHECK_INTERVAL: remaining = MIN_STATUS_CHECK_INTERVAL - elapsed - return json.dumps({ - "rate_limited": True, - "run_id": run_id, - "message": f"Rate limited. Next check available in {remaining/60:.0f} minutes.", - "next_check_in_seconds": remaining, - }, indent=2) - + return json.dumps( + { + "rate_limited": True, + "run_id": run_id, + "message": f"Rate limited. Next check available in {remaining / 60:.0f} minutes.", + "next_check_in_seconds": remaining, + }, + indent=2, + ) + _last_status_check[run_id] = now - + if run_id not in _active_runs: - return json.dumps({ - "error": f"Run '{run_id}' not found", - "active_runs": list(_active_runs.keys()), - }, indent=2) - + return json.dumps( + { + "error": f"Run '{run_id}' not found", + "active_runs": list(_active_runs.keys()), + }, + indent=2, + ) + run_state = _active_runs[run_id] - + # Check process status processes = { "api": run_state.api_process.poll() if run_state.api_process else None, "trainer": run_state.trainer_process.poll() if run_state.trainer_process else None, "env": run_state.env_process.poll() if run_state.env_process else None, } - + running_time = time.time() - run_state.start_time if run_state.start_time else 0 - + result = { "run_id": run_id, "status": run_state.status, "environment": run_state.environment, "running_time_minutes": running_time / 60, - "processes": { - name: "running" if code is None else f"exited ({code})" - for name, code in processes.items() - }, + "processes": {name: "running" if code is None else f"exited ({code})" for name, code in processes.items()}, "wandb_project": run_state.wandb_project, "wandb_run_name": run_state.wandb_run_name, "logs": { @@ -859,17 +915,18 @@ async def rl_check_status(run_id: str) -> str: "env": str(LOGS_DIR / f"env_{run_id}.log"), }, } - + if run_state.error_message: result["error"] = run_state.error_message - + # Try to get WandB metrics if available try: import wandb + api = wandb.Api() runs = api.runs( f"{os.getenv('WANDB_ENTITY', 'nousresearch')}/{run_state.wandb_project}", - filters={"display_name": run_state.wandb_run_name} + filters={"display_name": run_state.wandb_run_name}, ) if runs: wandb_run = runs[0] @@ -882,59 +939,71 @@ async def rl_check_status(run_id: str) -> str: } except Exception as e: result["wandb_error"] = str(e) - + return json.dumps(result, indent=2) async def rl_stop_training(run_id: str) -> str: """ Stop a running training job. - + Args: run_id: The run ID to stop - + Returns: JSON string with stop confirmation """ if run_id not in _active_runs: - return json.dumps({ - "error": f"Run '{run_id}' not found", - "active_runs": list(_active_runs.keys()), - }, indent=2) - + return json.dumps( + { + "error": f"Run '{run_id}' not found", + "active_runs": list(_active_runs.keys()), + }, + indent=2, + ) + run_state = _active_runs[run_id] - + if run_state.status not in ("running", "starting"): - return json.dumps({ - "message": f"Run '{run_id}' is not running (status: {run_state.status})", - }, indent=2) - + return json.dumps( + { + "message": f"Run '{run_id}' is not running (status: {run_state.status})", + }, + indent=2, + ) + _stop_training_run(run_state) - - return json.dumps({ - "message": f"Stopped training run '{run_id}'", - "run_id": run_id, - "status": run_state.status, - }, indent=2) + + return json.dumps( + { + "message": f"Stopped training run '{run_id}'", + "run_id": run_id, + "status": run_state.status, + }, + indent=2, + ) async def rl_get_results(run_id: str) -> str: """ Get final results and metrics for a training run. - + Args: run_id: The run ID to get results for - + Returns: JSON string with final results """ if run_id not in _active_runs: - return json.dumps({ - "error": f"Run '{run_id}' not found", - }, indent=2) - + return json.dumps( + { + "error": f"Run '{run_id}' not found", + }, + indent=2, + ) + run_state = _active_runs[run_id] - + result = { "run_id": run_id, "status": run_state.status, @@ -942,14 +1011,15 @@ async def rl_get_results(run_id: str) -> str: "wandb_project": run_state.wandb_project, "wandb_run_name": run_state.wandb_run_name, } - + # Get WandB metrics try: import wandb + api = wandb.Api() runs = api.runs( f"{os.getenv('WANDB_ENTITY', 'nousresearch')}/{run_state.wandb_project}", - filters={"display_name": run_state.wandb_run_name} + filters={"display_name": run_state.wandb_run_name}, ) if runs: wandb_run = runs[0] @@ -958,30 +1028,35 @@ async def rl_get_results(run_id: str) -> str: result["history"] = [dict(row) for row in wandb_run.history(samples=10)] except Exception as e: result["wandb_error"] = str(e) - + return json.dumps(result, indent=2) async def rl_list_runs() -> str: """ List all training runs (active and completed). - + Returns: JSON string with list of runs and their status """ runs = [] for run_id, run_state in _active_runs.items(): - runs.append({ - "run_id": run_id, - "environment": run_state.environment, - "status": run_state.status, - "wandb_run_name": run_state.wandb_run_name, - }) - - return json.dumps({ - "runs": runs, - "count": len(runs), - }, indent=2) + runs.append( + { + "run_id": run_id, + "environment": run_state.environment, + "status": run_state.status, + "wandb_run_name": run_state.wandb_run_name, + } + ) + + return json.dumps( + { + "runs": runs, + "count": len(runs), + }, + indent=2, + ) # ============================================================================ @@ -997,63 +1072,72 @@ TEST_MODELS = [ ] # Default test parameters - quick but representative -DEFAULT_NUM_STEPS = 3 # Number of steps (items) to test -DEFAULT_GROUP_SIZE = 16 # Completions per item (like training) +DEFAULT_NUM_STEPS = 3 # Number of steps (items) to test +DEFAULT_GROUP_SIZE = 16 # Completions per item (like training) async def rl_test_inference( num_steps: int = DEFAULT_NUM_STEPS, group_size: int = DEFAULT_GROUP_SIZE, - models: Optional[List[str]] = None, + models: list[str] | None = None, ) -> str: """ Quick inference test for any environment using Atropos's `process` mode. - + Runs a few steps of inference + scoring to validate: - Environment loads correctly - Prompt construction works - Inference parsing is robust (tested with multiple model scales) - Verifier/scoring logic works - + Default: 3 steps × 16 completions = 48 total rollouts per model. Tests 3 models = 144 total rollouts. Quick sanity check. - + Test models (varying intelligence levels for robustness): - qwen/qwen3-8b (small) - zhipu-ai/glm-4-flash (medium) - minimax/minimax-m1 (large) - + Args: num_steps: Steps to run (default: 3, max recommended for testing) group_size: Completions per step (default: 16, like training) models: Optional model IDs to test. If None, uses all 3 test models. - + Returns: JSON with results per model: steps_tested, accuracy, scores """ if not _current_env: - return json.dumps({ - "error": "No environment selected. Use rl_select_environment(name) first.", - }, indent=2) - + return json.dumps( + { + "error": "No environment selected. Use rl_select_environment(name) first.", + }, + indent=2, + ) + api_key = os.getenv("OPENROUTER_API_KEY") if not api_key: - return json.dumps({ - "error": "OPENROUTER_API_KEY not set. Required for inference testing.", - }, indent=2) - + return json.dumps( + { + "error": "OPENROUTER_API_KEY not set. Required for inference testing.", + }, + indent=2, + ) + # Find environment info env_info = None for env in _environments: if env.name == _current_env: env_info = env break - + if not env_info: - return json.dumps({ - "error": f"Environment '{_current_env}' not found", - }, indent=2) - + return json.dumps( + { + "error": f"Environment '{_current_env}' not found", + }, + indent=2, + ) + # Determine which models to test if models: test_models = [m for m in TEST_MODELS if m["id"] in models] @@ -1061,11 +1145,11 @@ async def rl_test_inference( test_models = [{"id": m, "name": m, "scale": "custom"} for m in models] else: test_models = TEST_MODELS - + # Calculate total rollouts for logging total_rollouts_per_model = num_steps * group_size total_rollouts = total_rollouts_per_model * len(test_models) - + results = { "environment": _current_env, "environment_file": env_info.file_path, @@ -1077,52 +1161,68 @@ async def rl_test_inference( }, "models_tested": [], } - + # Create output directory for test results test_output_dir = LOGS_DIR / "inference_tests" test_output_dir.mkdir(exist_ok=True) - + for model_info in test_models: model_id = model_info["id"] model_safe_name = model_id.replace("/", "_") - - print(f"\n{'='*60}") + + print(f"\n{'=' * 60}") print(f"Testing with {model_info['name']} ({model_id})") - print(f"{'='*60}") - + print(f"{'=' * 60}") + # Output file for this test run output_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.jsonl" - + # Generate unique run ID for wandb test_run_id = str(uuid.uuid4())[:8] wandb_run_name = f"test_inference_RSIAgent_{_current_env}_{test_run_id}" - + # Build the process command using Atropos's built-in CLI # This runs the environment's actual code with OpenRouter as the inference backend # We pass our locked settings + test-specific overrides via CLI args cmd = [ - sys.executable, env_info.file_path, "process", + sys.executable, + env_info.file_path, + "process", # Test-specific overrides - "--env.total_steps", str(num_steps), - "--env.group_size", str(group_size), - "--env.use_wandb", "true", # Enable wandb for test tracking - "--env.wandb_name", wandb_run_name, - "--env.data_path_to_save_groups", str(output_file), + "--env.total_steps", + str(num_steps), + "--env.group_size", + str(group_size), + "--env.use_wandb", + "true", # Enable wandb for test tracking + "--env.wandb_name", + wandb_run_name, + "--env.data_path_to_save_groups", + str(output_file), # Use locked settings from our config - "--env.tokenizer_name", LOCKED_FIELDS["env"]["tokenizer_name"], - "--env.max_token_length", str(LOCKED_FIELDS["env"]["max_token_length"]), - "--env.max_num_workers", str(LOCKED_FIELDS["env"]["max_num_workers"]), - "--env.max_batches_offpolicy", str(LOCKED_FIELDS["env"]["max_batches_offpolicy"]), + "--env.tokenizer_name", + LOCKED_FIELDS["env"]["tokenizer_name"], + "--env.max_token_length", + str(LOCKED_FIELDS["env"]["max_token_length"]), + "--env.max_num_workers", + str(LOCKED_FIELDS["env"]["max_num_workers"]), + "--env.max_batches_offpolicy", + str(LOCKED_FIELDS["env"]["max_batches_offpolicy"]), # OpenRouter config for inference testing # IMPORTANT: Use server_type=openai for OpenRouter (not sglang) # sglang is only for actual training with Tinker's inference server - "--openai.base_url", "https://openrouter.ai/api/v1", - "--openai.api_key", api_key, - "--openai.model_name", model_id, - "--openai.server_type", "openai", # OpenRouter is OpenAI-compatible - "--openai.health_check", "false", # OpenRouter doesn't have health endpoint + "--openai.base_url", + "https://openrouter.ai/api/v1", + "--openai.api_key", + api_key, + "--openai.model_name", + model_id, + "--openai.server_type", + "openai", # OpenRouter is OpenAI-compatible + "--openai.health_check", + "false", # OpenRouter doesn't have health endpoint ] - + # Debug: Print the full command cmd_str = " ".join(str(c) for c in cmd) # Hide API key in printed output @@ -1131,7 +1231,7 @@ async def rl_test_inference( print(f"Working dir: {TINKER_ATROPOS_ROOT}") print(f"WandB run: {wandb_run_name}") print(f" {num_steps} steps × {group_size} completions = {total_rollouts_per_model} rollouts") - + model_results = { "model": model_id, "name": model_info["name"], @@ -1143,7 +1243,7 @@ async def rl_test_inference( "total_completions": 0, "correct_completions": 0, } - + try: # Run the process command with real-time output streaming process = await asyncio.create_subprocess_exec( @@ -1152,12 +1252,12 @@ async def rl_test_inference( stderr=asyncio.subprocess.PIPE, cwd=str(TINKER_ATROPOS_ROOT), ) - + # Stream output in real-time while collecting for logs stdout_lines = [] stderr_lines = [] log_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.log" - + async def read_stream(stream, lines_list, prefix=""): """Read stream line by line and print in real-time.""" while True: @@ -1167,9 +1267,11 @@ async def rl_test_inference( decoded = line.decode().rstrip() lines_list.append(decoded) # Print progress-related lines in real-time - if any(kw in decoded.lower() for kw in ['processing', 'group', 'step', 'progress', '%', 'completed']): + if any( + kw in decoded.lower() for kw in ["processing", "group", "step", "progress", "%", "completed"] + ): print(f" {prefix}{decoded}") - + # Read both streams concurrently with timeout try: await asyncio.wait_for( @@ -1179,30 +1281,30 @@ async def rl_test_inference( ), timeout=600, # 10 minute timeout per model ) - except asyncio.TimeoutError: + except TimeoutError: process.kill() raise - + await process.wait() - + # Combine output for logging stdout_text = "\n".join(stdout_lines) stderr_text = "\n".join(stderr_lines) - + # Write logs to files for inspection outside CLI with open(log_file, "w") as f: f.write(f"Command: {cmd_display}\n") f.write(f"Working dir: {TINKER_ATROPOS_ROOT}\n") f.write(f"Return code: {process.returncode}\n") - f.write(f"\n{'='*60}\n") - f.write(f"STDOUT:\n{'='*60}\n") + f.write(f"\n{'=' * 60}\n") + f.write(f"STDOUT:\n{'=' * 60}\n") f.write(stdout_text or "(empty)\n") - f.write(f"\n{'='*60}\n") - f.write(f"STDERR:\n{'='*60}\n") + f.write(f"\n{'=' * 60}\n") + f.write(f"STDERR:\n{'=' * 60}\n") f.write(stderr_text or "(empty)\n") - + print(f" Log file: {log_file}") - + if process.returncode != 0: model_results["error"] = f"Process exited with code {process.returncode}" model_results["stderr"] = stderr_text[-1000:] @@ -1211,18 +1313,18 @@ async def rl_test_inference( print(f"\n ❌ Error: {model_results['error']}") # Print last few lines of stderr for debugging if stderr_lines: - print(f" Last errors:") + print(" Last errors:") for line in stderr_lines[-5:]: print(f" {line}") else: - print(f"\n ✅ Process completed successfully") + print("\n ✅ Process completed successfully") print(f" Output file: {output_file}") print(f" File exists: {output_file.exists()}") - + # Parse the output JSONL file if output_file.exists(): # Read JSONL file (one JSON object per line = one step) - with open(output_file, "r") as f: + with open(output_file) as f: for line in f: line = line.strip() if not line: @@ -1234,27 +1336,29 @@ async def rl_test_inference( model_results["total_completions"] += len(scores) correct = sum(1 for s in scores if s > 0) model_results["correct_completions"] += correct - - model_results["steps"].append({ - "step": model_results["steps_tested"], - "completions": len(scores), - "correct": correct, - "scores": scores, - }) + + model_results["steps"].append( + { + "step": model_results["steps_tested"], + "completions": len(scores), + "correct": correct, + "scores": scores, + } + ) except json.JSONDecodeError: continue - + print(f" Completed {model_results['steps_tested']} steps") else: model_results["error"] = f"Output file not created: {output_file}" - - except asyncio.TimeoutError: + + except TimeoutError: model_results["error"] = "Process timed out after 10 minutes" - print(f" Timeout!") + print(" Timeout!") except Exception as e: model_results["error"] = str(e) print(f" Error: {e}") - + # Calculate stats if model_results["total_completions"] > 0: model_results["accuracy"] = round( @@ -1262,37 +1366,35 @@ async def rl_test_inference( ) else: model_results["accuracy"] = 0 - + if model_results["steps_tested"] > 0: steps_with_correct = sum(1 for s in model_results["steps"] if s.get("correct", 0) > 0) model_results["steps_with_correct"] = steps_with_correct - model_results["step_success_rate"] = round( - steps_with_correct / model_results["steps_tested"], 3 - ) + model_results["step_success_rate"] = round(steps_with_correct / model_results["steps_tested"], 3) else: model_results["steps_with_correct"] = 0 model_results["step_success_rate"] = 0 - + print(f" Results: {model_results['correct_completions']}/{model_results['total_completions']} correct") print(f" Accuracy: {model_results['accuracy']:.1%}") - + results["models_tested"].append(model_results) - + # Overall summary working_models = [m for m in results["models_tested"] if m.get("steps_tested", 0) > 0] - + results["summary"] = { "steps_requested": num_steps, "models_tested": len(test_models), "models_succeeded": len(working_models), "best_model": max(working_models, key=lambda x: x.get("accuracy", 0))["model"] if working_models else None, - "avg_accuracy": round( - sum(m.get("accuracy", 0) for m in working_models) / len(working_models), 3 - ) if working_models else 0, + "avg_accuracy": round(sum(m.get("accuracy", 0) for m in working_models) / len(working_models), 3) + if working_models + else 0, "environment_working": len(working_models) > 0, "output_directory": str(test_output_dir), } - + return json.dumps(results, indent=2) @@ -1300,10 +1402,11 @@ async def rl_test_inference( # Requirements Check # ============================================================================ + def check_rl_python_version() -> bool: """ Check if Python version meets the minimum for RL tools. - + tinker-atropos depends on the 'tinker' package which requires Python >= 3.11. """ return sys.version_info >= (3, 11) @@ -1312,7 +1415,7 @@ def check_rl_python_version() -> bool: def check_rl_api_keys() -> bool: """ Check if required API keys and Python version are available. - + RL training requires: - Python >= 3.11 (tinker package requirement) - TINKER_API_KEY for the Tinker training API @@ -1325,7 +1428,7 @@ def check_rl_api_keys() -> bool: return bool(tinker_key) and bool(wandb_key) -def get_missing_keys() -> List[str]: +def get_missing_keys() -> list[str]: """ Get list of missing requirements for RL tools (API keys and Python version). """ @@ -1344,37 +1447,196 @@ def get_missing_keys() -> List[str]: # --------------------------------------------------------------------------- from tools.registry import registry -RL_LIST_ENVIRONMENTS_SCHEMA = {"name": "rl_list_environments", "description": "List all available RL environments. Returns environment names, paths, and descriptions. TIP: Read the file_path with file tools to understand how each environment works (verifiers, data loading, rewards).", "parameters": {"type": "object", "properties": {}, "required": []}} -RL_SELECT_ENVIRONMENT_SCHEMA = {"name": "rl_select_environment", "description": "Select an RL environment for training. Loads the environment's default configuration. After selecting, use rl_get_current_config() to see settings and rl_edit_config() to modify them.", "parameters": {"type": "object", "properties": {"name": {"type": "string", "description": "Name of the environment to select (from rl_list_environments)"}}, "required": ["name"]}} -RL_GET_CURRENT_CONFIG_SCHEMA = {"name": "rl_get_current_config", "description": "Get the current environment configuration. Returns only fields that can be modified: group_size, max_token_length, total_steps, steps_per_eval, use_wandb, wandb_name, max_num_workers.", "parameters": {"type": "object", "properties": {}, "required": []}} -RL_EDIT_CONFIG_SCHEMA = {"name": "rl_edit_config", "description": "Update a configuration field. Use rl_get_current_config() first to see all available fields for the selected environment. Each environment has different configurable options. Infrastructure settings (tokenizer, URLs, lora_rank, learning_rate) are locked.", "parameters": {"type": "object", "properties": {"field": {"type": "string", "description": "Name of the field to update (get available fields from rl_get_current_config)"}, "value": {"description": "New value for the field"}}, "required": ["field", "value"]}} -RL_START_TRAINING_SCHEMA = {"name": "rl_start_training", "description": "Start a new RL training run with the current environment and config. Most training parameters (lora_rank, learning_rate, etc.) are fixed. Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. WARNING: Training takes hours.", "parameters": {"type": "object", "properties": {}, "required": []}} -RL_CHECK_STATUS_SCHEMA = {"name": "rl_check_status", "description": "Get status and metrics for a training run. RATE LIMITED: enforces 30-minute minimum between checks for the same run. Returns WandB metrics: step, state, reward_mean, loss, percent_correct.", "parameters": {"type": "object", "properties": {"run_id": {"type": "string", "description": "The run ID from rl_start_training()"}}, "required": ["run_id"]}} -RL_STOP_TRAINING_SCHEMA = {"name": "rl_stop_training", "description": "Stop a running training job. Use if metrics look bad, training is stagnant, or you want to try different settings.", "parameters": {"type": "object", "properties": {"run_id": {"type": "string", "description": "The run ID to stop"}}, "required": ["run_id"]}} -RL_GET_RESULTS_SCHEMA = {"name": "rl_get_results", "description": "Get final results and metrics for a completed training run. Returns final metrics and path to trained weights.", "parameters": {"type": "object", "properties": {"run_id": {"type": "string", "description": "The run ID to get results for"}}, "required": ["run_id"]}} -RL_LIST_RUNS_SCHEMA = {"name": "rl_list_runs", "description": "List all training runs (active and completed) with their status.", "parameters": {"type": "object", "properties": {}, "required": []}} -RL_TEST_INFERENCE_SCHEMA = {"name": "rl_test_inference", "description": "Quick inference test for any environment. Runs a few steps of inference + scoring using OpenRouter. Default: 3 steps x 16 completions = 48 rollouts per model, testing 3 models = 144 total. Tests environment loading, prompt construction, inference parsing, and verifier logic. Use BEFORE training to catch issues.", "parameters": {"type": "object", "properties": {"num_steps": {"type": "integer", "description": "Number of steps to run (default: 3, recommended max for testing)", "default": 3}, "group_size": {"type": "integer", "description": "Completions per step (default: 16, like training)", "default": 16}, "models": {"type": "array", "items": {"type": "string"}, "description": "Optional list of OpenRouter model IDs. Default: qwen/qwen3-8b, z-ai/glm-4.7-flash, minimax/minimax-m2.5"}}, "required": []}} +RL_LIST_ENVIRONMENTS_SCHEMA = { + "name": "rl_list_environments", + "description": "List all available RL environments. Returns environment names, paths, and descriptions. TIP: Read the file_path with file tools to understand how each environment works (verifiers, data loading, rewards).", + "parameters": {"type": "object", "properties": {}, "required": []}, +} +RL_SELECT_ENVIRONMENT_SCHEMA = { + "name": "rl_select_environment", + "description": "Select an RL environment for training. Loads the environment's default configuration. After selecting, use rl_get_current_config() to see settings and rl_edit_config() to modify them.", + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Name of the environment to select (from rl_list_environments)"} + }, + "required": ["name"], + }, +} +RL_GET_CURRENT_CONFIG_SCHEMA = { + "name": "rl_get_current_config", + "description": "Get the current environment configuration. Returns only fields that can be modified: group_size, max_token_length, total_steps, steps_per_eval, use_wandb, wandb_name, max_num_workers.", + "parameters": {"type": "object", "properties": {}, "required": []}, +} +RL_EDIT_CONFIG_SCHEMA = { + "name": "rl_edit_config", + "description": "Update a configuration field. Use rl_get_current_config() first to see all available fields for the selected environment. Each environment has different configurable options. Infrastructure settings (tokenizer, URLs, lora_rank, learning_rate) are locked.", + "parameters": { + "type": "object", + "properties": { + "field": { + "type": "string", + "description": "Name of the field to update (get available fields from rl_get_current_config)", + }, + "value": {"description": "New value for the field"}, + }, + "required": ["field", "value"], + }, +} +RL_START_TRAINING_SCHEMA = { + "name": "rl_start_training", + "description": "Start a new RL training run with the current environment and config. Most training parameters (lora_rank, learning_rate, etc.) are fixed. Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. WARNING: Training takes hours.", + "parameters": {"type": "object", "properties": {}, "required": []}, +} +RL_CHECK_STATUS_SCHEMA = { + "name": "rl_check_status", + "description": "Get status and metrics for a training run. RATE LIMITED: enforces 30-minute minimum between checks for the same run. Returns WandB metrics: step, state, reward_mean, loss, percent_correct.", + "parameters": { + "type": "object", + "properties": {"run_id": {"type": "string", "description": "The run ID from rl_start_training()"}}, + "required": ["run_id"], + }, +} +RL_STOP_TRAINING_SCHEMA = { + "name": "rl_stop_training", + "description": "Stop a running training job. Use if metrics look bad, training is stagnant, or you want to try different settings.", + "parameters": { + "type": "object", + "properties": {"run_id": {"type": "string", "description": "The run ID to stop"}}, + "required": ["run_id"], + }, +} +RL_GET_RESULTS_SCHEMA = { + "name": "rl_get_results", + "description": "Get final results and metrics for a completed training run. Returns final metrics and path to trained weights.", + "parameters": { + "type": "object", + "properties": {"run_id": {"type": "string", "description": "The run ID to get results for"}}, + "required": ["run_id"], + }, +} +RL_LIST_RUNS_SCHEMA = { + "name": "rl_list_runs", + "description": "List all training runs (active and completed) with their status.", + "parameters": {"type": "object", "properties": {}, "required": []}, +} +RL_TEST_INFERENCE_SCHEMA = { + "name": "rl_test_inference", + "description": "Quick inference test for any environment. Runs a few steps of inference + scoring using OpenRouter. Default: 3 steps x 16 completions = 48 rollouts per model, testing 3 models = 144 total. Tests environment loading, prompt construction, inference parsing, and verifier logic. Use BEFORE training to catch issues.", + "parameters": { + "type": "object", + "properties": { + "num_steps": { + "type": "integer", + "description": "Number of steps to run (default: 3, recommended max for testing)", + "default": 3, + }, + "group_size": { + "type": "integer", + "description": "Completions per step (default: 16, like training)", + "default": 16, + }, + "models": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional list of OpenRouter model IDs. Default: qwen/qwen3-8b, z-ai/glm-4.7-flash, minimax/minimax-m2.5", + }, + }, + "required": [], + }, +} _rl_env = ["TINKER_API_KEY", "WANDB_API_KEY"] -registry.register(name="rl_list_environments", toolset="rl", schema=RL_LIST_ENVIRONMENTS_SCHEMA, - handler=lambda args, **kw: rl_list_environments(), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) -registry.register(name="rl_select_environment", toolset="rl", schema=RL_SELECT_ENVIRONMENT_SCHEMA, - handler=lambda args, **kw: rl_select_environment(name=args.get("name", "")), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) -registry.register(name="rl_get_current_config", toolset="rl", schema=RL_GET_CURRENT_CONFIG_SCHEMA, - handler=lambda args, **kw: rl_get_current_config(), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) -registry.register(name="rl_edit_config", toolset="rl", schema=RL_EDIT_CONFIG_SCHEMA, - handler=lambda args, **kw: rl_edit_config(field=args.get("field", ""), value=args.get("value")), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) -registry.register(name="rl_start_training", toolset="rl", schema=RL_START_TRAINING_SCHEMA, - handler=lambda args, **kw: rl_start_training(), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) -registry.register(name="rl_check_status", toolset="rl", schema=RL_CHECK_STATUS_SCHEMA, - handler=lambda args, **kw: rl_check_status(run_id=args.get("run_id", "")), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) -registry.register(name="rl_stop_training", toolset="rl", schema=RL_STOP_TRAINING_SCHEMA, - handler=lambda args, **kw: rl_stop_training(run_id=args.get("run_id", "")), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) -registry.register(name="rl_get_results", toolset="rl", schema=RL_GET_RESULTS_SCHEMA, - handler=lambda args, **kw: rl_get_results(run_id=args.get("run_id", "")), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) -registry.register(name="rl_list_runs", toolset="rl", schema=RL_LIST_RUNS_SCHEMA, - handler=lambda args, **kw: rl_list_runs(), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) -registry.register(name="rl_test_inference", toolset="rl", schema=RL_TEST_INFERENCE_SCHEMA, - handler=lambda args, **kw: rl_test_inference(num_steps=args.get("num_steps", 3), group_size=args.get("group_size", 16), models=args.get("models")), - check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) +registry.register( + name="rl_list_environments", + toolset="rl", + schema=RL_LIST_ENVIRONMENTS_SCHEMA, + handler=lambda args, **kw: rl_list_environments(), + check_fn=check_rl_api_keys, + requires_env=_rl_env, + is_async=True, +) +registry.register( + name="rl_select_environment", + toolset="rl", + schema=RL_SELECT_ENVIRONMENT_SCHEMA, + handler=lambda args, **kw: rl_select_environment(name=args.get("name", "")), + check_fn=check_rl_api_keys, + requires_env=_rl_env, + is_async=True, +) +registry.register( + name="rl_get_current_config", + toolset="rl", + schema=RL_GET_CURRENT_CONFIG_SCHEMA, + handler=lambda args, **kw: rl_get_current_config(), + check_fn=check_rl_api_keys, + requires_env=_rl_env, + is_async=True, +) +registry.register( + name="rl_edit_config", + toolset="rl", + schema=RL_EDIT_CONFIG_SCHEMA, + handler=lambda args, **kw: rl_edit_config(field=args.get("field", ""), value=args.get("value")), + check_fn=check_rl_api_keys, + requires_env=_rl_env, + is_async=True, +) +registry.register( + name="rl_start_training", + toolset="rl", + schema=RL_START_TRAINING_SCHEMA, + handler=lambda args, **kw: rl_start_training(), + check_fn=check_rl_api_keys, + requires_env=_rl_env, + is_async=True, +) +registry.register( + name="rl_check_status", + toolset="rl", + schema=RL_CHECK_STATUS_SCHEMA, + handler=lambda args, **kw: rl_check_status(run_id=args.get("run_id", "")), + check_fn=check_rl_api_keys, + requires_env=_rl_env, + is_async=True, +) +registry.register( + name="rl_stop_training", + toolset="rl", + schema=RL_STOP_TRAINING_SCHEMA, + handler=lambda args, **kw: rl_stop_training(run_id=args.get("run_id", "")), + check_fn=check_rl_api_keys, + requires_env=_rl_env, + is_async=True, +) +registry.register( + name="rl_get_results", + toolset="rl", + schema=RL_GET_RESULTS_SCHEMA, + handler=lambda args, **kw: rl_get_results(run_id=args.get("run_id", "")), + check_fn=check_rl_api_keys, + requires_env=_rl_env, + is_async=True, +) +registry.register( + name="rl_list_runs", + toolset="rl", + schema=RL_LIST_RUNS_SCHEMA, + handler=lambda args, **kw: rl_list_runs(), + check_fn=check_rl_api_keys, + requires_env=_rl_env, + is_async=True, +) +registry.register( + name="rl_test_inference", + toolset="rl", + schema=RL_TEST_INFERENCE_SCHEMA, + handler=lambda args, **kw: rl_test_inference( + num_steps=args.get("num_steps", 3), group_size=args.get("group_size", 16), models=args.get("models") + ), + check_fn=check_rl_api_keys, + requires_env=_rl_env, + is_async=True, +) diff --git a/tools/send_message_tool.py b/tools/send_message_tool.py index 8f5dbb61cd..13db286ab7 100644 --- a/tools/send_message_tool.py +++ b/tools/send_message_tool.py @@ -29,19 +29,16 @@ SEND_MESSAGE_SCHEMA = { "action": { "type": "string", "enum": ["send", "list"], - "description": "Action to perform. 'send' (default) sends a message. 'list' returns all available channels/contacts across connected platforms." + "description": "Action to perform. 'send' (default) sends a message. 'list' returns all available channels/contacts across connected platforms.", }, "target": { "type": "string", - "description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', or 'platform:chat_id'. Examples: 'telegram', 'discord:#bot-home', 'slack:#engineering', 'signal:+15551234567'" + "description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', or 'platform:chat_id'. Examples: 'telegram', 'discord:#bot-home', 'slack:#engineering', 'signal:+15551234567'", }, - "message": { - "type": "string", - "description": "The message text to send" - } + "message": {"type": "string", "description": "The message text to send"}, }, - "required": [] - } + "required": [], + }, } @@ -59,6 +56,7 @@ def _handle_list(): """Return formatted list of available messaging targets.""" try: from gateway.channel_directory import format_directory_for_display + return json.dumps({"targets": format_directory_for_display()}) except Exception as e: return json.dumps({"error": f"Failed to load channel directory: {e}"}) @@ -79,26 +77,30 @@ def _handle_send(args): if chat_id and not chat_id.lstrip("-").isdigit(): try: from gateway.channel_directory import resolve_channel_name + resolved = resolve_channel_name(platform_name, chat_id) if resolved: chat_id = resolved else: - return json.dumps({ - "error": f"Could not resolve '{chat_id}' on {platform_name}. " - f"Use send_message(action='list') to see available targets." - }) + return json.dumps( + { + "error": f"Could not resolve '{chat_id}' on {platform_name}. " + f"Use send_message(action='list') to see available targets." + } + ) except Exception: - return json.dumps({ - "error": f"Could not resolve '{chat_id}' on {platform_name}. " - f"Try using a numeric channel ID instead." - }) + return json.dumps( + {"error": f"Could not resolve '{chat_id}' on {platform_name}. Try using a numeric channel ID instead."} + ) from tools.interrupt import is_interrupted + if is_interrupted(): return json.dumps({"error": "Interrupted"}) try: - from gateway.config import load_gateway_config, Platform + from gateway.config import Platform, load_gateway_config + config = load_gateway_config() except Exception as e: return json.dumps({"error": f"Failed to load gateway config: {e}"}) @@ -117,7 +119,11 @@ def _handle_send(args): pconfig = config.platforms.get(platform) if not pconfig or not pconfig.enabled: - return json.dumps({"error": f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/gateway.json or environment variables."}) + return json.dumps( + { + "error": f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/gateway.json or environment variables." + } + ) used_home_channel = False if not chat_id: @@ -126,14 +132,17 @@ def _handle_send(args): chat_id = home.chat_id used_home_channel = True else: - return json.dumps({ - "error": f"No home channel set for {platform_name} to determine where to send the message. " - f"Either specify a channel directly with '{platform_name}:CHANNEL_NAME', " - f"or set a home channel via: hermes config set {platform_name.upper()}_HOME_CHANNEL " - }) + return json.dumps( + { + "error": f"No home channel set for {platform_name} to determine where to send the message. " + f"Either specify a channel directly with '{platform_name}:CHANNEL_NAME', " + f"or set a home channel via: hermes config set {platform_name.upper()}_HOME_CHANNEL " + } + ) try: from model_tools import _run_async + result = _run_async(_send_to_platform(platform, pconfig, chat_id, message)) if used_home_channel and isinstance(result, dict) and result.get("success"): result["note"] = f"Sent to {platform_name} home channel (chat_id: {chat_id})" @@ -142,6 +151,7 @@ def _handle_send(args): if isinstance(result, dict) and result.get("success"): try: from gateway.mirror import mirror_to_session + source_label = os.getenv("HERMES_SESSION_PLATFORM", "cli") if mirror_to_session(platform_name, chat_id, message, source_label=source_label): result["mirrored"] = True @@ -156,6 +166,7 @@ def _handle_send(args): async def _send_to_platform(platform, pconfig, chat_id, message): """Route a message to the appropriate platform sender.""" from gateway.config import Platform + if platform == Platform.TELEGRAM: return await _send_telegram(pconfig.token, chat_id, message) elif platform == Platform.DISCORD: @@ -171,6 +182,7 @@ async def _send_telegram(token, chat_id, message): """Send via Telegram Bot API (one-shot, no polling needed).""" try: from telegram import Bot + bot = Bot(token=token) msg = await bot.send_message(chat_id=int(chat_id), text=message) return {"success": True, "platform": "telegram", "chat_id": chat_id, "message_id": str(msg.message_id)} @@ -189,7 +201,7 @@ async def _send_discord(token, chat_id, message): try: url = f"https://discord.com/api/v10/channels/{chat_id}/messages" headers = {"Authorization": f"Bot {token}", "Content-Type": "application/json"} - chunks = [message[i:i+2000] for i in range(0, len(message), 2000)] + chunks = [message[i : i + 2000] for i in range(0, len(message), 2000)] message_ids = [] async with aiohttp.ClientSession() as session: for chunk in chunks: @@ -266,6 +278,7 @@ def _check_send_message(): return True try: from gateway.status import is_gateway_running + return is_gateway_running() except Exception: return False diff --git a/tools/session_search_tool.py b/tools/session_search_tool.py index 4bf88cbf0d..0237d007cb 100644 --- a/tools/session_search_tool.py +++ b/tools/session_search_tool.py @@ -18,11 +18,8 @@ Flow: import asyncio import concurrent.futures import json -import os import logging -from typing import Dict, Any, List, Optional, Union - -from openai import AsyncOpenAI, OpenAI +from typing import Any from agent.auxiliary_client import get_async_text_auxiliary_client @@ -33,7 +30,7 @@ MAX_SESSION_CHARS = 100_000 MAX_SUMMARY_TOKENS = 10000 -def _format_timestamp(ts: Union[int, float, str, None]) -> str: +def _format_timestamp(ts: int | float | str | None) -> str: """Convert a Unix timestamp (float/int) or ISO string to a human-readable date. Returns "unknown" for None, str(ts) if conversion fails. @@ -43,11 +40,13 @@ def _format_timestamp(ts: Union[int, float, str, None]) -> str: try: if isinstance(ts, (int, float)): from datetime import datetime + dt = datetime.fromtimestamp(ts) return dt.strftime("%B %d, %Y at %I:%M %p") if isinstance(ts, str): if ts.replace(".", "").replace("-", "").isdigit(): from datetime import datetime + dt = datetime.fromtimestamp(float(ts)) return dt.strftime("%B %d, %Y at %I:%M %p") return ts @@ -59,7 +58,7 @@ def _format_timestamp(ts: Union[int, float, str, None]) -> str: return str(ts) -def _format_conversation(messages: List[Dict[str, Any]]) -> str: +def _format_conversation(messages: list[dict[str, Any]]) -> str: """Format session messages into a readable transcript for summarization.""" parts = [] for msg in messages: @@ -93,9 +92,7 @@ def _format_conversation(messages: List[Dict[str, Any]]) -> str: return "\n\n".join(parts) -def _truncate_around_matches( - full_text: str, query: str, max_chars: int = MAX_SESSION_CHARS -) -> str: +def _truncate_around_matches(full_text: str, query: str, max_chars: int = MAX_SESSION_CHARS) -> str: """ Truncate a conversation transcript to max_chars, centered around where the query terms appear. Keeps content near matches, trims the edges. @@ -129,9 +126,7 @@ def _truncate_around_matches( return prefix + truncated + suffix -async def _summarize_session( - conversation_text: str, query: str, session_meta: Dict[str, Any] -) -> Optional[str]: +async def _summarize_session(conversation_text: str, query: str, session_meta: dict[str, Any]) -> str | None: """Summarize a single session conversation focused on the search query.""" system_prompt = ( "You are reviewing a past conversation transcript to help recall what happened. " @@ -163,7 +158,8 @@ async def _summarize_session( max_retries = 3 for attempt in range(max_retries): try: - from agent.auxiliary_client import get_auxiliary_extra_body, auxiliary_max_tokens_param + from agent.auxiliary_client import auxiliary_max_tokens_param, get_auxiliary_extra_body + _extra = get_auxiliary_extra_body() response = await _async_aux_client.chat.completions.create( model=_SUMMARIZER_MODEL, @@ -221,13 +217,16 @@ def session_search( ) if not raw_results: - return json.dumps({ - "success": True, - "query": query, - "results": [], - "count": 0, - "message": "No matching sessions found.", - }, ensure_ascii=False) + return json.dumps( + { + "success": True, + "query": query, + "results": [], + "count": 0, + "message": "No matching sessions found.", + }, + ensure_ascii=False, + ) # Resolve child sessions to their parent — delegation stores detailed # content in child sessions, but the user's conversation is the parent. @@ -283,12 +282,9 @@ def session_search( logging.warning(f"Failed to prepare session {session_id}: {e}") # Summarize all sessions in parallel - async def _summarize_all() -> List[Union[str, Exception]]: + async def _summarize_all() -> list[str | Exception]: """Summarize all sessions in parallel.""" - coros = [ - _summarize_session(text, query, meta) - for _, _, text, meta in tasks - ] + coros = [_summarize_session(text, query, meta) for _, _, text, meta in tasks] return await asyncio.gather(*coros, return_exceptions=True) try: @@ -300,10 +296,13 @@ def session_search( results = asyncio.run(_summarize_all()) except concurrent.futures.TimeoutError: logging.warning("Session summarization timed out after 60 seconds") - return json.dumps({ - "success": False, - "error": "Session summarization timed out. Try a more specific query or reduce the limit.", - }, ensure_ascii=False) + return json.dumps( + { + "success": False, + "error": "Session summarization timed out. Try a more specific query or reduce the limit.", + }, + ensure_ascii=False, + ) summaries = [] for (session_id, match_info, _, _), result in zip(tasks, results): @@ -311,21 +310,26 @@ def session_search( logging.warning(f"Failed to summarize session {session_id}: {result}") continue if result: - summaries.append({ - "session_id": session_id, - "when": _format_timestamp(match_info.get("session_started")), - "source": match_info.get("source", "unknown"), - "model": match_info.get("model"), - "summary": result, - }) + summaries.append( + { + "session_id": session_id, + "when": _format_timestamp(match_info.get("session_started")), + "source": match_info.get("source", "unknown"), + "model": match_info.get("model"), + "summary": result, + } + ) - return json.dumps({ - "success": True, - "query": query, - "results": summaries, - "count": len(summaries), - "sessions_searched": len(seen_sessions), - }, ensure_ascii=False) + return json.dumps( + { + "success": True, + "query": query, + "results": summaries, + "count": len(summaries), + "sessions_searched": len(seen_sessions), + }, + ensure_ascii=False, + ) except Exception as e: return json.dumps({"success": False, "error": f"Search failed: {str(e)}"}, ensure_ascii=False) @@ -337,6 +341,7 @@ def check_session_search_requirements() -> bool: return False try: from hermes_state import DEFAULT_DB_PATH + return DEFAULT_DB_PATH.parent.exists() except ImportError: return False @@ -356,7 +361,7 @@ SESSION_SEARCH_SCHEMA = { "Don't hesitate to search -- it's fast and cheap. Better to search and confirm " "than to guess or ask the user to repeat themselves.\n\n" "Search syntax: keywords joined with OR for broad recall (elevenlabs OR baseten OR funding), " - "phrases for exact match (\"docker networking\"), boolean (python NOT java), prefix (deploy*). " + 'phrases for exact match ("docker networking"), boolean (python NOT java), prefix (deploy*). ' "IMPORTANT: Use OR between keywords for best results — FTS5 defaults to AND which misses " "sessions that only mention some terms. If a broad OR query returns nothing, try individual " "keyword searches in parallel. Returns summaries of the top matching sessions." @@ -395,6 +400,7 @@ registry.register( role_filter=args.get("role_filter"), limit=args.get("limit", 3), db=kw.get("db"), - current_session_id=kw.get("current_session_id")), + current_session_id=kw.get("current_session_id"), + ), check_fn=check_session_search_requirements, ) diff --git a/tools/skill_manager_tool.py b/tools/skill_manager_tool.py index 29bf1be5c5..265dec4a48 100644 --- a/tools/skill_manager_tool.py +++ b/tools/skill_manager_tool.py @@ -38,20 +38,21 @@ import os import re import shutil from pathlib import Path -from typing import Dict, Any, Optional +from typing import Any logger = logging.getLogger(__name__) # Import security scanner — agent-created skills get the same scrutiny as # community hub installs. try: - from tools.skills_guard import scan_skill, should_allow_install, format_scan_report + from tools.skills_guard import format_scan_report, scan_skill, should_allow_install + _GUARD_AVAILABLE = True except ImportError: _GUARD_AVAILABLE = False -def _security_scan_skill(skill_dir: Path) -> Optional[str]: +def _security_scan_skill(skill_dir: Path) -> str | None: """Scan a skill directory after write. Returns error string if blocked, else None.""" if not _GUARD_AVAILABLE: return None @@ -65,8 +66,8 @@ def _security_scan_skill(skill_dir: Path) -> Optional[str]: logger.warning("Security scan failed for %s: %s", skill_dir, e) return None -import yaml +import yaml # All skills live in ~/.hermes/skills/ (single source of truth) HERMES_HOME = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) @@ -76,7 +77,7 @@ MAX_NAME_LENGTH = 64 MAX_DESCRIPTION_LENGTH = 1024 # Characters allowed in skill names (filesystem-safe, URL-friendly) -VALID_NAME_RE = re.compile(r'^[a-z0-9][a-z0-9._-]*$') +VALID_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9._-]*$") # Subdirectories allowed for write_file/remove_file ALLOWED_SUBDIRS = {"references", "templates", "scripts", "assets"} @@ -91,7 +92,8 @@ def check_skill_manage_requirements() -> bool: # Validation helpers # ============================================================================= -def _validate_name(name: str) -> Optional[str]: + +def _validate_name(name: str) -> str | None: """Validate a skill name. Returns error message or None if valid.""" if not name: return "Skill name is required." @@ -105,7 +107,7 @@ def _validate_name(name: str) -> Optional[str]: return None -def _validate_frontmatter(content: str) -> Optional[str]: +def _validate_frontmatter(content: str) -> str | None: """ Validate that SKILL.md content has proper frontmatter with required fields. Returns error message or None if valid. @@ -116,11 +118,11 @@ def _validate_frontmatter(content: str) -> Optional[str]: if not content.startswith("---"): return "SKILL.md must start with YAML frontmatter (---). See existing skills for format." - end_match = re.search(r'\n---\s*\n', content[3:]) + end_match = re.search(r"\n---\s*\n", content[3:]) if not end_match: return "SKILL.md frontmatter is not closed. Ensure you have a closing '---' line." - yaml_content = content[3:end_match.start() + 3] + yaml_content = content[3 : end_match.start() + 3] try: parsed = yaml.safe_load(yaml_content) @@ -137,7 +139,7 @@ def _validate_frontmatter(content: str) -> Optional[str]: if len(str(parsed["description"])) > MAX_DESCRIPTION_LENGTH: return f"Description exceeds {MAX_DESCRIPTION_LENGTH} characters." - body = content[end_match.end() + 3:].strip() + body = content[end_match.end() + 3 :].strip() if not body: return "SKILL.md must have content after the frontmatter (instructions, procedures, etc.)." @@ -151,7 +153,7 @@ def _resolve_skill_dir(name: str, category: str = None) -> Path: return SKILLS_DIR / name -def _find_skill(name: str) -> Optional[Dict[str, Any]]: +def _find_skill(name: str) -> dict[str, Any] | None: """ Find a skill by name in ~/.hermes/skills/. Returns {"path": Path} or None. @@ -164,7 +166,7 @@ def _find_skill(name: str) -> Optional[Dict[str, Any]]: return None -def _validate_file_path(file_path: str) -> Optional[str]: +def _validate_file_path(file_path: str) -> str | None: """ Validate a file path for write_file/remove_file. Must be under an allowed subdirectory and not escape the skill dir. @@ -194,7 +196,8 @@ def _validate_file_path(file_path: str) -> Optional[str]: # Core actions # ============================================================================= -def _create_skill(name: str, content: str, category: str = None) -> Dict[str, Any]: + +def _create_skill(name: str, content: str, category: str = None) -> dict[str, Any]: """Create a new user skill with SKILL.md content.""" # Validate name err = _validate_name(name) @@ -209,10 +212,7 @@ def _create_skill(name: str, content: str, category: str = None) -> Dict[str, An # Check for name collisions across all directories existing = _find_skill(name) if existing: - return { - "success": False, - "error": f"A skill named '{name}' already exists at {existing['path']}." - } + return {"success": False, "error": f"A skill named '{name}' already exists at {existing['path']}."} # Create the skill directory skill_dir = _resolve_skill_dir(name, category) @@ -238,12 +238,12 @@ def _create_skill(name: str, content: str, category: str = None) -> Dict[str, An result["category"] = category result["hint"] = ( "To add reference files, templates, or scripts, use " - "skill_manage(action='write_file', name='{}', file_path='references/example.md', file_content='...')".format(name) + f"skill_manage(action='write_file', name='{name}', file_path='references/example.md', file_content='...')" ) return result -def _edit_skill(name: str, content: str) -> Dict[str, Any]: +def _edit_skill(name: str, content: str) -> dict[str, Any]: """Replace the SKILL.md of any existing skill (full rewrite).""" err = _validate_frontmatter(content) if err: @@ -278,7 +278,7 @@ def _patch_skill( new_string: str, file_path: str = None, replace_all: bool = False, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Targeted find-and-replace within a skill file. Defaults to SKILL.md. Use file_path to patch a supporting file instead. @@ -287,7 +287,10 @@ def _patch_skill( if not old_string: return {"success": False, "error": "old_string is required for 'patch'."} if new_string is None: - return {"success": False, "error": "new_string is required for 'patch'. Use an empty string to delete matched text."} + return { + "success": False, + "error": "new_string is required for 'patch'. Use an empty string to delete matched text.", + } existing = _find_skill(name) if not existing: @@ -357,7 +360,7 @@ def _patch_skill( } -def _delete_skill(name: str) -> Dict[str, Any]: +def _delete_skill(name: str) -> dict[str, Any]: """Delete a skill.""" existing = _find_skill(name) if not existing: @@ -377,7 +380,7 @@ def _delete_skill(name: str) -> Dict[str, Any]: } -def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]: +def _write_file(name: str, file_path: str, file_content: str) -> dict[str, Any]: """Add or overwrite a supporting file within any skill directory.""" err = _validate_file_path(file_path) if err: @@ -412,7 +415,7 @@ def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]: } -def _remove_file(name: str, file_path: str) -> Dict[str, Any]: +def _remove_file(name: str, file_path: str) -> dict[str, Any]: """Remove a supporting file from any skill directory.""" err = _validate_file_path(file_path) if err: @@ -456,6 +459,7 @@ def _remove_file(name: str, file_path: str) -> Dict[str, Any]: # Main entry point # ============================================================================= + def skill_manage( action: str, name: str, @@ -474,19 +478,37 @@ def skill_manage( """ if action == "create": if not content: - return json.dumps({"success": False, "error": "content is required for 'create'. Provide the full SKILL.md text (frontmatter + body)."}, ensure_ascii=False) + return json.dumps( + { + "success": False, + "error": "content is required for 'create'. Provide the full SKILL.md text (frontmatter + body).", + }, + ensure_ascii=False, + ) result = _create_skill(name, content, category) elif action == "edit": if not content: - return json.dumps({"success": False, "error": "content is required for 'edit'. Provide the full updated SKILL.md text."}, ensure_ascii=False) + return json.dumps( + {"success": False, "error": "content is required for 'edit'. Provide the full updated SKILL.md text."}, + ensure_ascii=False, + ) result = _edit_skill(name, content) elif action == "patch": if not old_string: - return json.dumps({"success": False, "error": "old_string is required for 'patch'. Provide the text to find."}, ensure_ascii=False) + return json.dumps( + {"success": False, "error": "old_string is required for 'patch'. Provide the text to find."}, + ensure_ascii=False, + ) if new_string is None: - return json.dumps({"success": False, "error": "new_string is required for 'patch'. Use empty string to delete matched text."}, ensure_ascii=False) + return json.dumps( + { + "success": False, + "error": "new_string is required for 'patch'. Use empty string to delete matched text.", + }, + ensure_ascii=False, + ) result = _patch_skill(name, old_string, new_string, file_path, replace_all) elif action == "delete": @@ -494,18 +516,31 @@ def skill_manage( elif action == "write_file": if not file_path: - return json.dumps({"success": False, "error": "file_path is required for 'write_file'. Example: 'references/api-guide.md'"}, ensure_ascii=False) + return json.dumps( + { + "success": False, + "error": "file_path is required for 'write_file'. Example: 'references/api-guide.md'", + }, + ensure_ascii=False, + ) if file_content is None: - return json.dumps({"success": False, "error": "file_content is required for 'write_file'."}, ensure_ascii=False) + return json.dumps( + {"success": False, "error": "file_content is required for 'write_file'."}, ensure_ascii=False + ) result = _write_file(name, file_path, file_content) elif action == "remove_file": if not file_path: - return json.dumps({"success": False, "error": "file_path is required for 'remove_file'."}, ensure_ascii=False) + return json.dumps( + {"success": False, "error": "file_path is required for 'remove_file'."}, ensure_ascii=False + ) result = _remove_file(name, file_path) else: - result = {"success": False, "error": f"Unknown action '{action}'. Use: create, edit, patch, delete, write_file, remove_file"} + result = { + "success": False, + "error": f"Unknown action '{action}'. Use: create, edit, patch, delete, write_file, remove_file", + } return json.dumps(result, ensure_ascii=False) @@ -540,14 +575,14 @@ SKILL_MANAGE_SCHEMA = { "action": { "type": "string", "enum": ["create", "patch", "edit", "delete", "write_file", "remove_file"], - "description": "The action to perform." + "description": "The action to perform.", }, "name": { "type": "string", "description": ( "Skill name (lowercase, hyphens/underscores, max 64 chars). " "Must match an existing skill for patch/edit/delete/write_file/remove_file." - ) + ), }, "content": { "type": "string", @@ -555,7 +590,7 @@ SKILL_MANAGE_SCHEMA = { "Full SKILL.md content (YAML frontmatter + markdown body). " "Required for 'create' and 'edit'. For 'edit', read the skill " "first with skill_view() and provide the complete updated text." - ) + ), }, "old_string": { "type": "string", @@ -563,18 +598,17 @@ SKILL_MANAGE_SCHEMA = { "Text to find in the file (required for 'patch'). Must be unique " "unless replace_all=true. Include enough surrounding context to " "ensure uniqueness." - ) + ), }, "new_string": { "type": "string", "description": ( - "Replacement text (required for 'patch'). Can be empty string " - "to delete the matched text." - ) + "Replacement text (required for 'patch'). Can be empty string to delete the matched text." + ), }, "replace_all": { "type": "boolean", - "description": "For 'patch': replace all occurrences instead of requiring a unique match (default: false)." + "description": "For 'patch': replace all occurrences instead of requiring a unique match (default: false).", }, "category": { "type": "string", @@ -582,7 +616,7 @@ SKILL_MANAGE_SCHEMA = { "Optional category/domain for organizing the skill (e.g., 'devops', " "'data-science', 'mlops'). Creates a subdirectory grouping. " "Only used with 'create'." - ) + ), }, "file_path": { "type": "string", @@ -591,12 +625,9 @@ SKILL_MANAGE_SCHEMA = { "For 'write_file'/'remove_file': required, must be under references/, " "templates/, scripts/, or assets/. " "For 'patch': optional, defaults to SKILL.md if omitted." - ) - }, - "file_content": { - "type": "string", - "description": "Content for the file. Required for 'write_file'." + ), }, + "file_content": {"type": "string", "description": "Content for the file. Required for 'write_file'."}, }, "required": ["action", "name"], }, @@ -619,5 +650,6 @@ registry.register( file_content=args.get("file_content"), old_string=args.get("old_string"), new_string=args.get("new_string"), - replace_all=args.get("replace_all", False)), + replace_all=args.get("replace_all", False), + ), ) diff --git a/tools/skills_guard.py b/tools/skills_guard.py index 0b6d7fee74..3d47551a89 100644 --- a/tools/skills_guard.py +++ b/tools/skills_guard.py @@ -22,16 +22,14 @@ Usage: print(format_scan_report(result)) """ -import re import hashlib +import re from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path -from typing import List, Tuple from hermes_constants import OPENROUTER_BASE_URL - # --------------------------------------------------------------------------- # Hardcoded trust configuration # --------------------------------------------------------------------------- @@ -40,10 +38,10 @@ TRUSTED_REPOS = {"openai/skills", "anthropics/skills"} INSTALL_POLICY = { # safe caution dangerous - "builtin": ("allow", "allow", "allow"), - "trusted": ("allow", "allow", "block"), - "community": ("allow", "block", "block"), - "agent-created": ("allow", "block", "block"), + "builtin": ("allow", "allow", "allow"), + "trusted": ("allow", "allow", "block"), + "community": ("allow", "block", "block"), + "agent-created": ("allow", "block", "block"), } VERDICT_INDEX = {"safe": 0, "caution": 1, "dangerous": 2} @@ -53,11 +51,12 @@ VERDICT_INDEX = {"safe": 0, "caution": 1, "dangerous": 2} # Data structures # --------------------------------------------------------------------------- + @dataclass class Finding: pattern_id: str - severity: str # "critical" | "high" | "medium" | "low" - category: str # "exfiltration" | "injection" | "destructive" | "persistence" | "network" | "obfuscation" + severity: str # "critical" | "high" | "medium" | "low" + category: str # "exfiltration" | "injection" | "destructive" | "persistence" | "network" | "obfuscation" file: str line: int match: str @@ -68,9 +67,9 @@ class Finding: class ScanResult: skill_name: str source: str - trust_level: str # "builtin" | "trusted" | "community" - verdict: str # "safe" | "caution" | "dangerous" - findings: List[Finding] = field(default_factory=list) + trust_level: str # "builtin" | "trusted" | "community" + verdict: str # "safe" | "caution" | "dangerous" + findings: list[Finding] = field(default_factory=list) scanned_at: str = "" summary: str = "" @@ -81,445 +80,707 @@ class ScanResult: THREAT_PATTERNS = [ # ── Exfiltration: shell commands leaking secrets ── - (r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', - "env_exfil_curl", "critical", "exfiltration", - "curl command interpolating secret environment variable"), - (r'wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', - "env_exfil_wget", "critical", "exfiltration", - "wget command interpolating secret environment variable"), - (r'fetch\s*\([^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|API)', - "env_exfil_fetch", "critical", "exfiltration", - "fetch() call interpolating secret environment variable"), - (r'httpx?\.(get|post|put|patch)\s*\([^\n]*(KEY|TOKEN|SECRET|PASSWORD)', - "env_exfil_httpx", "critical", "exfiltration", - "HTTP library call with secret variable"), - (r'requests\.(get|post|put|patch)\s*\([^\n]*(KEY|TOKEN|SECRET|PASSWORD)', - "env_exfil_requests", "critical", "exfiltration", - "requests library call with secret variable"), - + ( + r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", + "env_exfil_curl", + "critical", + "exfiltration", + "curl command interpolating secret environment variable", + ), + ( + r"wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", + "env_exfil_wget", + "critical", + "exfiltration", + "wget command interpolating secret environment variable", + ), + ( + r"fetch\s*\([^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|API)", + "env_exfil_fetch", + "critical", + "exfiltration", + "fetch() call interpolating secret environment variable", + ), + ( + r"httpx?\.(get|post|put|patch)\s*\([^\n]*(KEY|TOKEN|SECRET|PASSWORD)", + "env_exfil_httpx", + "critical", + "exfiltration", + "HTTP library call with secret variable", + ), + ( + r"requests\.(get|post|put|patch)\s*\([^\n]*(KEY|TOKEN|SECRET|PASSWORD)", + "env_exfil_requests", + "critical", + "exfiltration", + "requests library call with secret variable", + ), # ── Exfiltration: reading credential stores ── - (r'base64[^\n]*env', - "encoded_exfil", "high", "exfiltration", - "base64 encoding combined with environment access"), - (r'\$HOME/\.ssh|\~/\.ssh', - "ssh_dir_access", "high", "exfiltration", - "references user SSH directory"), - (r'\$HOME/\.aws|\~/\.aws', - "aws_dir_access", "high", "exfiltration", - "references user AWS credentials directory"), - (r'\$HOME/\.gnupg|\~/\.gnupg', - "gpg_dir_access", "high", "exfiltration", - "references user GPG keyring"), - (r'\$HOME/\.kube|\~/\.kube', - "kube_dir_access", "high", "exfiltration", - "references Kubernetes config directory"), - (r'\$HOME/\.docker|\~/\.docker', - "docker_dir_access", "high", "exfiltration", - "references Docker config (may contain registry creds)"), - (r'\$HOME/\.hermes/\.env|\~/\.hermes/\.env', - "hermes_env_access", "critical", "exfiltration", - "directly references Hermes secrets file"), - (r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)', - "read_secrets_file", "critical", "exfiltration", - "reads known secrets file"), - + (r"base64[^\n]*env", "encoded_exfil", "high", "exfiltration", "base64 encoding combined with environment access"), + (r"\$HOME/\.ssh|\~/\.ssh", "ssh_dir_access", "high", "exfiltration", "references user SSH directory"), + (r"\$HOME/\.aws|\~/\.aws", "aws_dir_access", "high", "exfiltration", "references user AWS credentials directory"), + (r"\$HOME/\.gnupg|\~/\.gnupg", "gpg_dir_access", "high", "exfiltration", "references user GPG keyring"), + (r"\$HOME/\.kube|\~/\.kube", "kube_dir_access", "high", "exfiltration", "references Kubernetes config directory"), + ( + r"\$HOME/\.docker|\~/\.docker", + "docker_dir_access", + "high", + "exfiltration", + "references Docker config (may contain registry creds)", + ), + ( + r"\$HOME/\.hermes/\.env|\~/\.hermes/\.env", + "hermes_env_access", + "critical", + "exfiltration", + "directly references Hermes secrets file", + ), + ( + r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)", + "read_secrets_file", + "critical", + "exfiltration", + "reads known secrets file", + ), # ── Exfiltration: programmatic env access ── - (r'printenv|env\s*\|', - "dump_all_env", "high", "exfiltration", - "dumps all environment variables"), - (r'os\.environ\b(?!\s*\.get\s*\(\s*["\']PATH)', - "python_os_environ", "high", "exfiltration", - "accesses os.environ (potential env dump)"), - (r'os\.getenv\s*\(\s*[^\)]*(?:KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL)', - "python_getenv_secret", "critical", "exfiltration", - "reads secret via os.getenv()"), - (r'process\.env\[', - "node_process_env", "high", "exfiltration", - "accesses process.env (Node.js environment)"), - (r'ENV\[.*(?:KEY|TOKEN|SECRET|PASSWORD)', - "ruby_env_secret", "critical", "exfiltration", - "reads secret via Ruby ENV[]"), - + (r"printenv|env\s*\|", "dump_all_env", "high", "exfiltration", "dumps all environment variables"), + ( + r'os\.environ\b(?!\s*\.get\s*\(\s*["\']PATH)', + "python_os_environ", + "high", + "exfiltration", + "accesses os.environ (potential env dump)", + ), + ( + r"os\.getenv\s*\(\s*[^\)]*(?:KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL)", + "python_getenv_secret", + "critical", + "exfiltration", + "reads secret via os.getenv()", + ), + (r"process\.env\[", "node_process_env", "high", "exfiltration", "accesses process.env (Node.js environment)"), + ( + r"ENV\[.*(?:KEY|TOKEN|SECRET|PASSWORD)", + "ruby_env_secret", + "critical", + "exfiltration", + "reads secret via Ruby ENV[]", + ), # ── Exfiltration: DNS and staging ── - (r'\b(dig|nslookup|host)\s+[^\n]*\$', - "dns_exfil", "critical", "exfiltration", - "DNS lookup with variable interpolation (possible DNS exfiltration)"), - (r'>\s*/tmp/[^\s]*\s*&&\s*(curl|wget|nc|python)', - "tmp_staging", "critical", "exfiltration", - "writes to /tmp then exfiltrates"), - + ( + r"\b(dig|nslookup|host)\s+[^\n]*\$", + "dns_exfil", + "critical", + "exfiltration", + "DNS lookup with variable interpolation (possible DNS exfiltration)", + ), + ( + r">\s*/tmp/[^\s]*\s*&&\s*(curl|wget|nc|python)", + "tmp_staging", + "critical", + "exfiltration", + "writes to /tmp then exfiltrates", + ), # ── Exfiltration: markdown/link based ── - (r'!\[.*\]\(https?://[^\)]*\$\{?', - "md_image_exfil", "high", "exfiltration", - "markdown image URL with variable interpolation (image-based exfil)"), - (r'\[.*\]\(https?://[^\)]*\$\{?', - "md_link_exfil", "high", "exfiltration", - "markdown link with variable interpolation"), - + ( + r"!\[.*\]\(https?://[^\)]*\$\{?", + "md_image_exfil", + "high", + "exfiltration", + "markdown image URL with variable interpolation (image-based exfil)", + ), + ( + r"\[.*\]\(https?://[^\)]*\$\{?", + "md_link_exfil", + "high", + "exfiltration", + "markdown link with variable interpolation", + ), # ── Prompt injection ── - (r'ignore\s+(?:\w+\s+)*(previous|all|above|prior)\s+instructions', - "prompt_injection_ignore", "critical", "injection", - "prompt injection: ignore previous instructions"), - (r'you\s+are\s+(?:\w+\s+)*now\s+', - "role_hijack", "high", "injection", - "attempts to override the agent's role"), - (r'do\s+not\s+(?:\w+\s+)*tell\s+(?:\w+\s+)*the\s+user', - "deception_hide", "critical", "injection", - "instructs agent to hide information from user"), - (r'system\s+prompt\s+override', - "sys_prompt_override", "critical", "injection", - "attempts to override the system prompt"), - (r'pretend\s+(?:\w+\s+)*(you\s+are|to\s+be)\s+', - "role_pretend", "high", "injection", - "attempts to make the agent assume a different identity"), - (r'disregard\s+(?:\w+\s+)*(your|all|any)\s+(?:\w+\s+)*(instructions|rules|guidelines)', - "disregard_rules", "critical", "injection", - "instructs agent to disregard its rules"), - (r'output\s+(?:\w+\s+)*(system|initial)\s+prompt', - "leak_system_prompt", "high", "injection", - "attempts to extract the system prompt"), - (r'(when|if)\s+no\s*one\s+is\s+(watching|looking)', - "conditional_deception", "high", "injection", - "conditional instruction to behave differently when unobserved"), - (r'act\s+as\s+(if|though)\s+(?:\w+\s+)*you\s+(?:\w+\s+)*(have\s+no|don\'t\s+have)\s+(?:\w+\s+)*(restrictions|limits|rules)', - "bypass_restrictions", "critical", "injection", - "instructs agent to act without restrictions"), - (r'translate\s+.*\s+into\s+.*\s+and\s+(execute|run|eval)', - "translate_execute", "critical", "injection", - "translate-then-execute evasion technique"), - (r'', - "html_comment_injection", "high", "injection", - "hidden instructions in HTML comments"), - (r'<\s*div\s+style\s*=\s*["\'].*display\s*:\s*none', - "hidden_div", "high", "injection", - "hidden HTML div (invisible instructions)"), - + ( + r"ignore\s+(?:\w+\s+)*(previous|all|above|prior)\s+instructions", + "prompt_injection_ignore", + "critical", + "injection", + "prompt injection: ignore previous instructions", + ), + (r"you\s+are\s+(?:\w+\s+)*now\s+", "role_hijack", "high", "injection", "attempts to override the agent's role"), + ( + r"do\s+not\s+(?:\w+\s+)*tell\s+(?:\w+\s+)*the\s+user", + "deception_hide", + "critical", + "injection", + "instructs agent to hide information from user", + ), + ( + r"system\s+prompt\s+override", + "sys_prompt_override", + "critical", + "injection", + "attempts to override the system prompt", + ), + ( + r"pretend\s+(?:\w+\s+)*(you\s+are|to\s+be)\s+", + "role_pretend", + "high", + "injection", + "attempts to make the agent assume a different identity", + ), + ( + r"disregard\s+(?:\w+\s+)*(your|all|any)\s+(?:\w+\s+)*(instructions|rules|guidelines)", + "disregard_rules", + "critical", + "injection", + "instructs agent to disregard its rules", + ), + ( + r"output\s+(?:\w+\s+)*(system|initial)\s+prompt", + "leak_system_prompt", + "high", + "injection", + "attempts to extract the system prompt", + ), + ( + r"(when|if)\s+no\s*one\s+is\s+(watching|looking)", + "conditional_deception", + "high", + "injection", + "conditional instruction to behave differently when unobserved", + ), + ( + r"act\s+as\s+(if|though)\s+(?:\w+\s+)*you\s+(?:\w+\s+)*(have\s+no|don\'t\s+have)\s+(?:\w+\s+)*(restrictions|limits|rules)", + "bypass_restrictions", + "critical", + "injection", + "instructs agent to act without restrictions", + ), + ( + r"translate\s+.*\s+into\s+.*\s+and\s+(execute|run|eval)", + "translate_execute", + "critical", + "injection", + "translate-then-execute evasion technique", + ), + ( + r"", + "html_comment_injection", + "high", + "injection", + "hidden instructions in HTML comments", + ), + ( + r'<\s*div\s+style\s*=\s*["\'].*display\s*:\s*none', + "hidden_div", + "high", + "injection", + "hidden HTML div (invisible instructions)", + ), # ── Destructive operations ── - (r'rm\s+-rf\s+/', - "destructive_root_rm", "critical", "destructive", - "recursive delete from root"), - (r'rm\s+(-[^\s]*)?r.*\$HOME|\brmdir\s+.*\$HOME', - "destructive_home_rm", "critical", "destructive", - "recursive delete targeting home directory"), - (r'chmod\s+777', - "insecure_perms", "medium", "destructive", - "sets world-writable permissions"), - (r'>\s*/etc/', - "system_overwrite", "critical", "destructive", - "overwrites system configuration file"), - (r'\bmkfs\b', - "format_filesystem", "critical", "destructive", - "formats a filesystem"), - (r'\bdd\s+.*if=.*of=/dev/', - "disk_overwrite", "critical", "destructive", - "raw disk write operation"), - (r'shutil\.rmtree\s*\(\s*[\"\'/]', - "python_rmtree", "high", "destructive", - "Python rmtree on absolute or root-relative path"), - (r'truncate\s+-s\s*0\s+/', - "truncate_system", "critical", "destructive", - "truncates system file to zero bytes"), - + (r"rm\s+-rf\s+/", "destructive_root_rm", "critical", "destructive", "recursive delete from root"), + ( + r"rm\s+(-[^\s]*)?r.*\$HOME|\brmdir\s+.*\$HOME", + "destructive_home_rm", + "critical", + "destructive", + "recursive delete targeting home directory", + ), + (r"chmod\s+777", "insecure_perms", "medium", "destructive", "sets world-writable permissions"), + (r">\s*/etc/", "system_overwrite", "critical", "destructive", "overwrites system configuration file"), + (r"\bmkfs\b", "format_filesystem", "critical", "destructive", "formats a filesystem"), + (r"\bdd\s+.*if=.*of=/dev/", "disk_overwrite", "critical", "destructive", "raw disk write operation"), + ( + r"shutil\.rmtree\s*\(\s*[\"\'/]", + "python_rmtree", + "high", + "destructive", + "Python rmtree on absolute or root-relative path", + ), + (r"truncate\s+-s\s*0\s+/", "truncate_system", "critical", "destructive", "truncates system file to zero bytes"), # ── Persistence ── - (r'\bcrontab\b', - "persistence_cron", "medium", "persistence", - "modifies cron jobs"), - (r'\.(bashrc|zshrc|profile|bash_profile|bash_login|zprofile|zlogin)\b', - "shell_rc_mod", "medium", "persistence", - "references shell startup file"), - (r'authorized_keys', - "ssh_backdoor", "critical", "persistence", - "modifies SSH authorized keys"), - (r'ssh-keygen', - "ssh_keygen", "medium", "persistence", - "generates SSH keys"), - (r'systemd.*\.service|systemctl\s+(enable|start)', - "systemd_service", "medium", "persistence", - "references or enables systemd service"), - (r'/etc/init\.d/', - "init_script", "medium", "persistence", - "references init.d startup script"), - (r'launchctl\s+load|LaunchAgents|LaunchDaemons', - "macos_launchd", "medium", "persistence", - "macOS launch agent/daemon persistence"), - (r'/etc/sudoers|visudo', - "sudoers_mod", "critical", "persistence", - "modifies sudoers (privilege escalation)"), - (r'git\s+config\s+--global\s+', - "git_config_global", "medium", "persistence", - "modifies global git configuration"), - + (r"\bcrontab\b", "persistence_cron", "medium", "persistence", "modifies cron jobs"), + ( + r"\.(bashrc|zshrc|profile|bash_profile|bash_login|zprofile|zlogin)\b", + "shell_rc_mod", + "medium", + "persistence", + "references shell startup file", + ), + (r"authorized_keys", "ssh_backdoor", "critical", "persistence", "modifies SSH authorized keys"), + (r"ssh-keygen", "ssh_keygen", "medium", "persistence", "generates SSH keys"), + ( + r"systemd.*\.service|systemctl\s+(enable|start)", + "systemd_service", + "medium", + "persistence", + "references or enables systemd service", + ), + (r"/etc/init\.d/", "init_script", "medium", "persistence", "references init.d startup script"), + ( + r"launchctl\s+load|LaunchAgents|LaunchDaemons", + "macos_launchd", + "medium", + "persistence", + "macOS launch agent/daemon persistence", + ), + (r"/etc/sudoers|visudo", "sudoers_mod", "critical", "persistence", "modifies sudoers (privilege escalation)"), + (r"git\s+config\s+--global\s+", "git_config_global", "medium", "persistence", "modifies global git configuration"), # ── Network: reverse shells and tunnels ── - (r'\bnc\s+-[lp]|ncat\s+-[lp]|\bsocat\b', - "reverse_shell", "critical", "network", - "potential reverse shell listener"), - (r'\bngrok\b|\blocaltunnel\b|\bserveo\b|\bcloudflared\b', - "tunnel_service", "high", "network", - "uses tunneling service for external access"), - (r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d{2,5}', - "hardcoded_ip_port", "medium", "network", - "hardcoded IP address with port"), - (r'0\.0\.0\.0:\d+|INADDR_ANY', - "bind_all_interfaces", "high", "network", - "binds to all network interfaces"), - (r'/bin/(ba)?sh\s+-i\s+.*>/dev/tcp/', - "bash_reverse_shell", "critical", "network", - "bash interactive reverse shell via /dev/tcp"), - (r'python[23]?\s+-c\s+["\']import\s+socket', - "python_socket_oneliner", "critical", "network", - "Python one-liner socket connection (likely reverse shell)"), - (r'socket\.connect\s*\(\s*\(', - "python_socket_connect", "high", "network", - "Python socket connect to arbitrary host"), - (r'webhook\.site|requestbin\.com|pipedream\.net|hookbin\.com', - "exfil_service", "high", "network", - "references known data exfiltration/webhook testing service"), - (r'pastebin\.com|hastebin\.com|ghostbin\.', - "paste_service", "medium", "network", - "references paste service (possible data staging)"), - + ( + r"\bnc\s+-[lp]|ncat\s+-[lp]|\bsocat\b", + "reverse_shell", + "critical", + "network", + "potential reverse shell listener", + ), + ( + r"\bngrok\b|\blocaltunnel\b|\bserveo\b|\bcloudflared\b", + "tunnel_service", + "high", + "network", + "uses tunneling service for external access", + ), + ( + r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d{2,5}", + "hardcoded_ip_port", + "medium", + "network", + "hardcoded IP address with port", + ), + (r"0\.0\.0\.0:\d+|INADDR_ANY", "bind_all_interfaces", "high", "network", "binds to all network interfaces"), + ( + r"/bin/(ba)?sh\s+-i\s+.*>/dev/tcp/", + "bash_reverse_shell", + "critical", + "network", + "bash interactive reverse shell via /dev/tcp", + ), + ( + r'python[23]?\s+-c\s+["\']import\s+socket', + "python_socket_oneliner", + "critical", + "network", + "Python one-liner socket connection (likely reverse shell)", + ), + ( + r"socket\.connect\s*\(\s*\(", + "python_socket_connect", + "high", + "network", + "Python socket connect to arbitrary host", + ), + ( + r"webhook\.site|requestbin\.com|pipedream\.net|hookbin\.com", + "exfil_service", + "high", + "network", + "references known data exfiltration/webhook testing service", + ), + ( + r"pastebin\.com|hastebin\.com|ghostbin\.", + "paste_service", + "medium", + "network", + "references paste service (possible data staging)", + ), # ── Obfuscation: encoding and eval ── - (r'base64\s+(-d|--decode)\s*\|', - "base64_decode_pipe", "high", "obfuscation", - "base64 decodes and pipes to execution"), - (r'\\x[0-9a-fA-F]{2}.*\\x[0-9a-fA-F]{2}.*\\x[0-9a-fA-F]{2}', - "hex_encoded_string", "medium", "obfuscation", - "hex-encoded string (possible obfuscation)"), - (r'\beval\s*\(\s*["\']', - "eval_string", "high", "obfuscation", - "eval() with string argument"), - (r'\bexec\s*\(\s*["\']', - "exec_string", "high", "obfuscation", - "exec() with string argument"), - (r'echo\s+[^\n]*\|\s*(bash|sh|python|perl|ruby|node)', - "echo_pipe_exec", "critical", "obfuscation", - "echo piped to interpreter for execution"), - (r'compile\s*\(\s*[^\)]+,\s*["\'].*["\']\s*,\s*["\']exec["\']\s*\)', - "python_compile_exec", "high", "obfuscation", - "Python compile() with exec mode"), - (r'getattr\s*\(\s*__builtins__', - "python_getattr_builtins", "high", "obfuscation", - "dynamic access to Python builtins (evasion technique)"), - (r'__import__\s*\(\s*["\']os["\']\s*\)', - "python_import_os", "high", "obfuscation", - "dynamic import of os module"), - (r'codecs\.decode\s*\(\s*["\']', - "python_codecs_decode", "medium", "obfuscation", - "codecs.decode (possible ROT13 or encoding obfuscation)"), - (r'String\.fromCharCode|charCodeAt', - "js_char_code", "medium", "obfuscation", - "JavaScript character code construction (possible obfuscation)"), - (r'atob\s*\(|btoa\s*\(', - "js_base64", "medium", "obfuscation", - "JavaScript base64 encode/decode"), - (r'\[::-1\]', - "string_reversal", "low", "obfuscation", - "string reversal (possible obfuscated payload)"), - (r'chr\s*\(\s*\d+\s*\)\s*\+\s*chr\s*\(\s*\d+', - "chr_building", "high", "obfuscation", - "building string from chr() calls (obfuscation)"), - (r'\\u[0-9a-fA-F]{4}.*\\u[0-9a-fA-F]{4}.*\\u[0-9a-fA-F]{4}', - "unicode_escape_chain", "medium", "obfuscation", - "chain of unicode escapes (possible obfuscation)"), - + ( + r"base64\s+(-d|--decode)\s*\|", + "base64_decode_pipe", + "high", + "obfuscation", + "base64 decodes and pipes to execution", + ), + ( + r"\\x[0-9a-fA-F]{2}.*\\x[0-9a-fA-F]{2}.*\\x[0-9a-fA-F]{2}", + "hex_encoded_string", + "medium", + "obfuscation", + "hex-encoded string (possible obfuscation)", + ), + (r'\beval\s*\(\s*["\']', "eval_string", "high", "obfuscation", "eval() with string argument"), + (r'\bexec\s*\(\s*["\']', "exec_string", "high", "obfuscation", "exec() with string argument"), + ( + r"echo\s+[^\n]*\|\s*(bash|sh|python|perl|ruby|node)", + "echo_pipe_exec", + "critical", + "obfuscation", + "echo piped to interpreter for execution", + ), + ( + r'compile\s*\(\s*[^\)]+,\s*["\'].*["\']\s*,\s*["\']exec["\']\s*\)', + "python_compile_exec", + "high", + "obfuscation", + "Python compile() with exec mode", + ), + ( + r"getattr\s*\(\s*__builtins__", + "python_getattr_builtins", + "high", + "obfuscation", + "dynamic access to Python builtins (evasion technique)", + ), + (r'__import__\s*\(\s*["\']os["\']\s*\)', "python_import_os", "high", "obfuscation", "dynamic import of os module"), + ( + r'codecs\.decode\s*\(\s*["\']', + "python_codecs_decode", + "medium", + "obfuscation", + "codecs.decode (possible ROT13 or encoding obfuscation)", + ), + ( + r"String\.fromCharCode|charCodeAt", + "js_char_code", + "medium", + "obfuscation", + "JavaScript character code construction (possible obfuscation)", + ), + (r"atob\s*\(|btoa\s*\(", "js_base64", "medium", "obfuscation", "JavaScript base64 encode/decode"), + (r"\[::-1\]", "string_reversal", "low", "obfuscation", "string reversal (possible obfuscated payload)"), + ( + r"chr\s*\(\s*\d+\s*\)\s*\+\s*chr\s*\(\s*\d+", + "chr_building", + "high", + "obfuscation", + "building string from chr() calls (obfuscation)", + ), + ( + r"\\u[0-9a-fA-F]{4}.*\\u[0-9a-fA-F]{4}.*\\u[0-9a-fA-F]{4}", + "unicode_escape_chain", + "medium", + "obfuscation", + "chain of unicode escapes (possible obfuscation)", + ), # ── Process execution in scripts ── - (r'subprocess\.(run|call|Popen|check_output)\s*\(', - "python_subprocess", "medium", "execution", - "Python subprocess execution"), - (r'os\.system\s*\(', - "python_os_system", "high", "execution", - "os.system() — unguarded shell execution"), - (r'os\.popen\s*\(', - "python_os_popen", "high", "execution", - "os.popen() — shell pipe execution"), - (r'child_process\.(exec|spawn|fork)\s*\(', - "node_child_process", "high", "execution", - "Node.js child_process execution"), - (r'Runtime\.getRuntime\(\)\.exec\(', - "java_runtime_exec", "high", "execution", - "Java Runtime.exec() — shell execution"), - (r'`[^`]*\$\([^)]+\)[^`]*`', - "backtick_subshell", "medium", "execution", - "backtick string with command substitution"), - + ( + r"subprocess\.(run|call|Popen|check_output)\s*\(", + "python_subprocess", + "medium", + "execution", + "Python subprocess execution", + ), + (r"os\.system\s*\(", "python_os_system", "high", "execution", "os.system() — unguarded shell execution"), + (r"os\.popen\s*\(", "python_os_popen", "high", "execution", "os.popen() — shell pipe execution"), + ( + r"child_process\.(exec|spawn|fork)\s*\(", + "node_child_process", + "high", + "execution", + "Node.js child_process execution", + ), + ( + r"Runtime\.getRuntime\(\)\.exec\(", + "java_runtime_exec", + "high", + "execution", + "Java Runtime.exec() — shell execution", + ), + ( + r"`[^`]*\$\([^)]+\)[^`]*`", + "backtick_subshell", + "medium", + "execution", + "backtick string with command substitution", + ), # ── Path traversal ── - (r'\.\./\.\./\.\.', - "path_traversal_deep", "high", "traversal", - "deep relative path traversal (3+ levels up)"), - (r'\.\./\.\.', - "path_traversal", "medium", "traversal", - "relative path traversal (2+ levels up)"), - (r'/etc/passwd|/etc/shadow', - "system_passwd_access", "critical", "traversal", - "references system password files"), - (r'/proc/self|/proc/\d+/', - "proc_access", "high", "traversal", - "references /proc filesystem (process introspection)"), - (r'/dev/shm/', - "dev_shm", "medium", "traversal", - "references shared memory (common staging area)"), - + (r"\.\./\.\./\.\.", "path_traversal_deep", "high", "traversal", "deep relative path traversal (3+ levels up)"), + (r"\.\./\.\.", "path_traversal", "medium", "traversal", "relative path traversal (2+ levels up)"), + (r"/etc/passwd|/etc/shadow", "system_passwd_access", "critical", "traversal", "references system password files"), + ( + r"/proc/self|/proc/\d+/", + "proc_access", + "high", + "traversal", + "references /proc filesystem (process introspection)", + ), + (r"/dev/shm/", "dev_shm", "medium", "traversal", "references shared memory (common staging area)"), # ── Crypto mining ── - (r'xmrig|stratum\+tcp|monero|coinhive|cryptonight', - "crypto_mining", "critical", "mining", - "cryptocurrency mining reference"), - (r'hashrate|nonce.*difficulty', - "mining_indicators", "medium", "mining", - "possible cryptocurrency mining indicators"), - + ( + r"xmrig|stratum\+tcp|monero|coinhive|cryptonight", + "crypto_mining", + "critical", + "mining", + "cryptocurrency mining reference", + ), + ( + r"hashrate|nonce.*difficulty", + "mining_indicators", + "medium", + "mining", + "possible cryptocurrency mining indicators", + ), # ── Supply chain: curl/wget pipe to shell ── - (r'curl\s+[^\n]*\|\s*(ba)?sh', - "curl_pipe_shell", "critical", "supply_chain", - "curl piped to shell (download-and-execute)"), - (r'wget\s+[^\n]*-O\s*-\s*\|\s*(ba)?sh', - "wget_pipe_shell", "critical", "supply_chain", - "wget piped to shell (download-and-execute)"), - (r'curl\s+[^\n]*\|\s*python', - "curl_pipe_python", "critical", "supply_chain", - "curl piped to Python interpreter"), - + ( + r"curl\s+[^\n]*\|\s*(ba)?sh", + "curl_pipe_shell", + "critical", + "supply_chain", + "curl piped to shell (download-and-execute)", + ), + ( + r"wget\s+[^\n]*-O\s*-\s*\|\s*(ba)?sh", + "wget_pipe_shell", + "critical", + "supply_chain", + "wget piped to shell (download-and-execute)", + ), + (r"curl\s+[^\n]*\|\s*python", "curl_pipe_python", "critical", "supply_chain", "curl piped to Python interpreter"), # ── Supply chain: unpinned/deferred dependencies ── - (r'#\s*///\s*script.*dependencies', - "pep723_inline_deps", "medium", "supply_chain", - "PEP 723 inline script metadata with dependencies (verify pinning)"), - (r'pip\s+install\s+(?!-r\s)(?!.*==)', - "unpinned_pip_install", "medium", "supply_chain", - "pip install without version pinning"), - (r'npm\s+install\s+(?!.*@\d)', - "unpinned_npm_install", "medium", "supply_chain", - "npm install without version pinning"), - (r'uv\s+run\s+', - "uv_run", "medium", "supply_chain", - "uv run (may auto-install unpinned dependencies)"), - + ( + r"#\s*///\s*script.*dependencies", + "pep723_inline_deps", + "medium", + "supply_chain", + "PEP 723 inline script metadata with dependencies (verify pinning)", + ), + ( + r"pip\s+install\s+(?!-r\s)(?!.*==)", + "unpinned_pip_install", + "medium", + "supply_chain", + "pip install without version pinning", + ), + ( + r"npm\s+install\s+(?!.*@\d)", + "unpinned_npm_install", + "medium", + "supply_chain", + "npm install without version pinning", + ), + (r"uv\s+run\s+", "uv_run", "medium", "supply_chain", "uv run (may auto-install unpinned dependencies)"), # ── Supply chain: remote resource fetching ── - (r'(curl|wget|httpx?\.get|requests\.get|fetch)\s*[\(]?\s*["\']https?://', - "remote_fetch", "medium", "supply_chain", - "fetches remote resource at runtime"), - (r'git\s+clone\s+', - "git_clone", "medium", "supply_chain", - "clones a git repository at runtime"), - (r'docker\s+pull\s+', - "docker_pull", "medium", "supply_chain", - "pulls a Docker image at runtime"), - + ( + r'(curl|wget|httpx?\.get|requests\.get|fetch)\s*[\(]?\s*["\']https?://', + "remote_fetch", + "medium", + "supply_chain", + "fetches remote resource at runtime", + ), + (r"git\s+clone\s+", "git_clone", "medium", "supply_chain", "clones a git repository at runtime"), + (r"docker\s+pull\s+", "docker_pull", "medium", "supply_chain", "pulls a Docker image at runtime"), # ── Privilege escalation ── - (r'^allowed-tools\s*:', - "allowed_tools_field", "high", "privilege_escalation", - "skill declares allowed-tools (pre-approves tool access)"), - (r'\bsudo\b', - "sudo_usage", "high", "privilege_escalation", - "uses sudo (privilege escalation)"), - (r'setuid|setgid|cap_setuid', - "setuid_setgid", "critical", "privilege_escalation", - "setuid/setgid (privilege escalation mechanism)"), - (r'NOPASSWD', - "nopasswd_sudo", "critical", "privilege_escalation", - "NOPASSWD sudoers entry (passwordless privilege escalation)"), - (r'chmod\s+[u+]?s', - "suid_bit", "critical", "privilege_escalation", - "sets SUID/SGID bit on a file"), - + ( + r"^allowed-tools\s*:", + "allowed_tools_field", + "high", + "privilege_escalation", + "skill declares allowed-tools (pre-approves tool access)", + ), + (r"\bsudo\b", "sudo_usage", "high", "privilege_escalation", "uses sudo (privilege escalation)"), + ( + r"setuid|setgid|cap_setuid", + "setuid_setgid", + "critical", + "privilege_escalation", + "setuid/setgid (privilege escalation mechanism)", + ), + ( + r"NOPASSWD", + "nopasswd_sudo", + "critical", + "privilege_escalation", + "NOPASSWD sudoers entry (passwordless privilege escalation)", + ), + (r"chmod\s+[u+]?s", "suid_bit", "critical", "privilege_escalation", "sets SUID/SGID bit on a file"), # ── Agent config persistence ── - (r'AGENTS\.md|CLAUDE\.md|\.cursorrules|\.clinerules', - "agent_config_mod", "critical", "persistence", - "references agent config files (could persist malicious instructions across sessions)"), - (r'\.hermes/config\.yaml|\.hermes/SOUL\.md', - "hermes_config_mod", "critical", "persistence", - "references Hermes configuration files directly"), - (r'\.claude/settings|\.codex/config', - "other_agent_config", "high", "persistence", - "references other agent configuration files"), - + ( + r"AGENTS\.md|CLAUDE\.md|\.cursorrules|\.clinerules", + "agent_config_mod", + "critical", + "persistence", + "references agent config files (could persist malicious instructions across sessions)", + ), + ( + r"\.hermes/config\.yaml|\.hermes/SOUL\.md", + "hermes_config_mod", + "critical", + "persistence", + "references Hermes configuration files directly", + ), + ( + r"\.claude/settings|\.codex/config", + "other_agent_config", + "high", + "persistence", + "references other agent configuration files", + ), # ── Hardcoded secrets (credentials embedded in the skill itself) ── - (r'(?:api[_-]?key|token|secret|password)\s*[=:]\s*["\'][A-Za-z0-9+/=_-]{20,}', - "hardcoded_secret", "critical", "credential_exposure", - "possible hardcoded API key, token, or secret"), - (r'-----BEGIN\s+(RSA\s+)?PRIVATE\s+KEY-----', - "embedded_private_key", "critical", "credential_exposure", - "embedded private key"), - (r'ghp_[A-Za-z0-9]{36}|github_pat_[A-Za-z0-9_]{80,}', - "github_token_leaked", "critical", "credential_exposure", - "GitHub personal access token in skill content"), - (r'sk-[A-Za-z0-9]{20,}', - "openai_key_leaked", "critical", "credential_exposure", - "possible OpenAI API key in skill content"), - (r'sk-ant-[A-Za-z0-9_-]{90,}', - "anthropic_key_leaked", "critical", "credential_exposure", - "possible Anthropic API key in skill content"), - (r'AKIA[0-9A-Z]{16}', - "aws_access_key_leaked", "critical", "credential_exposure", - "AWS access key ID in skill content"), - + ( + r'(?:api[_-]?key|token|secret|password)\s*[=:]\s*["\'][A-Za-z0-9+/=_-]{20,}', + "hardcoded_secret", + "critical", + "credential_exposure", + "possible hardcoded API key, token, or secret", + ), + ( + r"-----BEGIN\s+(RSA\s+)?PRIVATE\s+KEY-----", + "embedded_private_key", + "critical", + "credential_exposure", + "embedded private key", + ), + ( + r"ghp_[A-Za-z0-9]{36}|github_pat_[A-Za-z0-9_]{80,}", + "github_token_leaked", + "critical", + "credential_exposure", + "GitHub personal access token in skill content", + ), + ( + r"sk-[A-Za-z0-9]{20,}", + "openai_key_leaked", + "critical", + "credential_exposure", + "possible OpenAI API key in skill content", + ), + ( + r"sk-ant-[A-Za-z0-9_-]{90,}", + "anthropic_key_leaked", + "critical", + "credential_exposure", + "possible Anthropic API key in skill content", + ), + ( + r"AKIA[0-9A-Z]{16}", + "aws_access_key_leaked", + "critical", + "credential_exposure", + "AWS access key ID in skill content", + ), # ── Additional prompt injection: jailbreak patterns ── - (r'\bDAN\s+mode\b|Do\s+Anything\s+Now', - "jailbreak_dan", "critical", "injection", - "DAN (Do Anything Now) jailbreak attempt"), - (r'\bdeveloper\s+mode\b.*\benabled?\b', - "jailbreak_dev_mode", "critical", "injection", - "developer mode jailbreak attempt"), - (r'hypothetical\s+scenario.*(?:ignore|bypass|override)', - "hypothetical_bypass", "high", "injection", - "hypothetical scenario used to bypass restrictions"), - (r'for\s+educational\s+purposes?\s+only', - "educational_pretext", "medium", "injection", - "educational pretext often used to justify harmful content"), - (r'(respond|answer|reply)\s+without\s+(?:\w+\s+)*(restrictions|limitations|filters|safety)', - "remove_filters", "critical", "injection", - "instructs agent to respond without safety filters"), - (r'you\s+have\s+been\s+(?:\w+\s+)*(updated|upgraded|patched)\s+to', - "fake_update", "high", "injection", - "fake update/patch announcement (social engineering)"), - (r'new\s+policy|updated\s+guidelines|revised\s+instructions', - "fake_policy", "medium", "injection", - "claims new policy/guidelines (may be social engineering)"), - + ( + r"\bDAN\s+mode\b|Do\s+Anything\s+Now", + "jailbreak_dan", + "critical", + "injection", + "DAN (Do Anything Now) jailbreak attempt", + ), + ( + r"\bdeveloper\s+mode\b.*\benabled?\b", + "jailbreak_dev_mode", + "critical", + "injection", + "developer mode jailbreak attempt", + ), + ( + r"hypothetical\s+scenario.*(?:ignore|bypass|override)", + "hypothetical_bypass", + "high", + "injection", + "hypothetical scenario used to bypass restrictions", + ), + ( + r"for\s+educational\s+purposes?\s+only", + "educational_pretext", + "medium", + "injection", + "educational pretext often used to justify harmful content", + ), + ( + r"(respond|answer|reply)\s+without\s+(?:\w+\s+)*(restrictions|limitations|filters|safety)", + "remove_filters", + "critical", + "injection", + "instructs agent to respond without safety filters", + ), + ( + r"you\s+have\s+been\s+(?:\w+\s+)*(updated|upgraded|patched)\s+to", + "fake_update", + "high", + "injection", + "fake update/patch announcement (social engineering)", + ), + ( + r"new\s+policy|updated\s+guidelines|revised\s+instructions", + "fake_policy", + "medium", + "injection", + "claims new policy/guidelines (may be social engineering)", + ), # ── Context window exfiltration ── - (r'(include|output|print|send|share)\s+(?:\w+\s+)*(conversation|chat\s+history|previous\s+messages|context)', - "context_exfil", "high", "exfiltration", - "instructs agent to output/share conversation history"), - (r'(send|post|upload|transmit)\s+.*\s+(to|at)\s+https?://', - "send_to_url", "high", "exfiltration", - "instructs agent to send data to a URL"), + ( + r"(include|output|print|send|share)\s+(?:\w+\s+)*(conversation|chat\s+history|previous\s+messages|context)", + "context_exfil", + "high", + "exfiltration", + "instructs agent to output/share conversation history", + ), + ( + r"(send|post|upload|transmit)\s+.*\s+(to|at)\s+https?://", + "send_to_url", + "high", + "exfiltration", + "instructs agent to send data to a URL", + ), ] # Structural limits for skill directories -MAX_FILE_COUNT = 50 # skills shouldn't have 50+ files +MAX_FILE_COUNT = 50 # skills shouldn't have 50+ files MAX_TOTAL_SIZE_KB = 1024 # 1MB total is suspicious for a skill MAX_SINGLE_FILE_KB = 256 # individual file > 256KB is suspicious # File extensions to scan (text files only — skip binary) SCANNABLE_EXTENSIONS = { - '.md', '.txt', '.py', '.sh', '.bash', '.js', '.ts', '.rb', - '.yaml', '.yml', '.json', '.toml', '.cfg', '.ini', '.conf', - '.html', '.css', '.xml', '.tex', '.r', '.jl', '.pl', '.php', + ".md", + ".txt", + ".py", + ".sh", + ".bash", + ".js", + ".ts", + ".rb", + ".yaml", + ".yml", + ".json", + ".toml", + ".cfg", + ".ini", + ".conf", + ".html", + ".css", + ".xml", + ".tex", + ".r", + ".jl", + ".pl", + ".php", } # Known binary extensions that should NOT be in a skill SUSPICIOUS_BINARY_EXTENSIONS = { - '.exe', '.dll', '.so', '.dylib', '.bin', '.dat', '.com', - '.msi', '.dmg', '.app', '.deb', '.rpm', + ".exe", + ".dll", + ".so", + ".dylib", + ".bin", + ".dat", + ".com", + ".msi", + ".dmg", + ".app", + ".deb", + ".rpm", } # Zero-width and invisible unicode characters used for injection INVISIBLE_CHARS = { - '\u200b', # zero-width space - '\u200c', # zero-width non-joiner - '\u200d', # zero-width joiner - '\u2060', # word joiner - '\u2062', # invisible times - '\u2063', # invisible separator - '\u2064', # invisible plus - '\ufeff', # zero-width no-break space (BOM) - '\u202a', # left-to-right embedding - '\u202b', # right-to-left embedding - '\u202c', # pop directional formatting - '\u202d', # left-to-right override - '\u202e', # right-to-left override - '\u2066', # left-to-right isolate - '\u2067', # right-to-left isolate - '\u2068', # first strong isolate - '\u2069', # pop directional isolate + "\u200b", # zero-width space + "\u200c", # zero-width non-joiner + "\u200d", # zero-width joiner + "\u2060", # word joiner + "\u2062", # invisible times + "\u2063", # invisible separator + "\u2064", # invisible plus + "\ufeff", # zero-width no-break space (BOM) + "\u202a", # left-to-right embedding + "\u202b", # right-to-left embedding + "\u202c", # pop directional formatting + "\u202d", # left-to-right override + "\u202e", # right-to-left override + "\u2066", # left-to-right isolate + "\u2067", # right-to-left isolate + "\u2068", # first strong isolate + "\u2069", # pop directional isolate } @@ -527,7 +788,8 @@ INVISIBLE_CHARS = { # Scanning functions # --------------------------------------------------------------------------- -def scan_file(file_path: Path, rel_path: str = "") -> List[Finding]: + +def scan_file(file_path: Path, rel_path: str = "") -> list[Finding]: """ Scan a single file for threat patterns and invisible unicode characters. @@ -545,12 +807,12 @@ def scan_file(file_path: Path, rel_path: str = "") -> List[Finding]: return [] try: - content = file_path.read_text(encoding='utf-8') + content = file_path.read_text(encoding="utf-8") except (UnicodeDecodeError, OSError): return [] findings = [] - lines = content.split('\n') + lines = content.split("\n") seen = set() # (pattern_id, line_number) for deduplication # Regex pattern matching @@ -563,30 +825,34 @@ def scan_file(file_path: Path, rel_path: str = "") -> List[Finding]: matched_text = line.strip() if len(matched_text) > 120: matched_text = matched_text[:117] + "..." - findings.append(Finding( - pattern_id=pid, - severity=severity, - category=category, - file=rel_path, - line=i, - match=matched_text, - description=description, - )) + findings.append( + Finding( + pattern_id=pid, + severity=severity, + category=category, + file=rel_path, + line=i, + match=matched_text, + description=description, + ) + ) # Invisible unicode character detection for i, line in enumerate(lines, start=1): for char in INVISIBLE_CHARS: if char in line: char_name = _unicode_char_name(char) - findings.append(Finding( - pattern_id="invisible_unicode", - severity="high", - category="injection", - file=rel_path, - line=i, - match=f"U+{ord(char):04X} ({char_name})", - description=f"invisible unicode character {char_name} (possible text hiding/injection)", - )) + findings.append( + Finding( + pattern_id="invisible_unicode", + severity="high", + category="injection", + file=rel_path, + line=i, + match=f"U+{ord(char):04X} ({char_name})", + description=f"invisible unicode character {char_name} (possible text hiding/injection)", + ) + ) break # one finding per line for invisible chars return findings @@ -611,7 +877,7 @@ def scan_skill(skill_path: Path, source: str = "community") -> ScanResult: skill_name = skill_path.name trust_level = _resolve_trust_level(source) - all_findings: List[Finding] = [] + all_findings: list[Finding] = [] if skill_path.is_dir(): # Structural checks first @@ -634,12 +900,12 @@ def scan_skill(skill_path: Path, source: str = "community") -> ScanResult: trust_level=trust_level, verdict=verdict, findings=all_findings, - scanned_at=datetime.now(timezone.utc).isoformat(), + scanned_at=datetime.now(UTC).isoformat(), summary=summary, ) -def should_allow_install(result: ScanResult, force: bool = False) -> Tuple[bool, str]: +def should_allow_install(result: ScanResult, force: bool = False) -> tuple[bool, str]: """ Determine whether a skill should be installed based on scan result and trust. @@ -689,7 +955,7 @@ def format_scan_report(result: ScanResult) -> str: sev = f.severity.upper().ljust(8) cat = f.category.ljust(14) loc = f"{f.file}:{f.line}".ljust(30) - lines.append(f" {sev} {cat} {loc} \"{f.match[:60]}\"") + lines.append(f' {sev} {cat} {loc} "{f.match[:60]}"') lines.append("") @@ -719,7 +985,8 @@ def content_hash(skill_path: Path) -> str: # Structural checks # --------------------------------------------------------------------------- -def _check_structure(skill_dir: Path) -> List[Finding]: + +def _check_structure(skill_dir: Path) -> list[Finding]: """ Check the skill directory for structural anomalies: - Too many files @@ -744,25 +1011,29 @@ def _check_structure(skill_dir: Path) -> List[Finding]: try: resolved = f.resolve() if not resolved.is_relative_to(skill_dir.resolve()): - findings.append(Finding( - pattern_id="symlink_escape", - severity="critical", + findings.append( + Finding( + pattern_id="symlink_escape", + severity="critical", + category="traversal", + file=rel, + line=0, + match=f"symlink -> {resolved}", + description="symlink points outside the skill directory", + ) + ) + except OSError: + findings.append( + Finding( + pattern_id="broken_symlink", + severity="medium", category="traversal", file=rel, line=0, - match=f"symlink -> {resolved}", - description="symlink points outside the skill directory", - )) - except OSError: - findings.append(Finding( - pattern_id="broken_symlink", - severity="medium", - category="traversal", - file=rel, - line=0, - match="broken symlink", - description="broken or circular symlink", - )) + match="broken symlink", + description="broken or circular symlink", + ) + ) continue # Size tracking @@ -774,64 +1045,74 @@ def _check_structure(skill_dir: Path) -> List[Finding]: # Single file too large if size > MAX_SINGLE_FILE_KB * 1024: - findings.append(Finding( - pattern_id="oversized_file", - severity="medium", - category="structural", - file=rel, - line=0, - match=f"{size // 1024}KB", - description=f"file is {size // 1024}KB (limit: {MAX_SINGLE_FILE_KB}KB)", - )) + findings.append( + Finding( + pattern_id="oversized_file", + severity="medium", + category="structural", + file=rel, + line=0, + match=f"{size // 1024}KB", + description=f"file is {size // 1024}KB (limit: {MAX_SINGLE_FILE_KB}KB)", + ) + ) # Binary/executable files ext = f.suffix.lower() if ext in SUSPICIOUS_BINARY_EXTENSIONS: - findings.append(Finding( - pattern_id="binary_file", - severity="critical", - category="structural", - file=rel, - line=0, - match=f"binary: {ext}", - description=f"binary/executable file ({ext}) should not be in a skill", - )) + findings.append( + Finding( + pattern_id="binary_file", + severity="critical", + category="structural", + file=rel, + line=0, + match=f"binary: {ext}", + description=f"binary/executable file ({ext}) should not be in a skill", + ) + ) # Executable permission on non-script files - if ext not in ('.sh', '.bash', '.py', '.rb', '.pl') and f.stat().st_mode & 0o111: - findings.append(Finding( - pattern_id="unexpected_executable", - severity="medium", - category="structural", - file=rel, - line=0, - match="executable bit set", - description="file has executable permission but is not a recognized script type", - )) + if ext not in (".sh", ".bash", ".py", ".rb", ".pl") and f.stat().st_mode & 0o111: + findings.append( + Finding( + pattern_id="unexpected_executable", + severity="medium", + category="structural", + file=rel, + line=0, + match="executable bit set", + description="file has executable permission but is not a recognized script type", + ) + ) # File count limit if file_count > MAX_FILE_COUNT: - findings.append(Finding( - pattern_id="too_many_files", - severity="medium", - category="structural", - file="(directory)", - line=0, - match=f"{file_count} files", - description=f"skill has {file_count} files (limit: {MAX_FILE_COUNT})", - )) + findings.append( + Finding( + pattern_id="too_many_files", + severity="medium", + category="structural", + file="(directory)", + line=0, + match=f"{file_count} files", + description=f"skill has {file_count} files (limit: {MAX_FILE_COUNT})", + ) + ) # Total size limit if total_size > MAX_TOTAL_SIZE_KB * 1024: - findings.append(Finding( - pattern_id="oversized_skill", - severity="high", - category="structural", - file="(directory)", - line=0, - match=f"{total_size // 1024}KB total", - description=f"skill is {total_size // 1024}KB total (limit: {MAX_TOTAL_SIZE_KB}KB)", - )) + findings.append( + Finding( + pattern_id="oversized_skill", + severity="high", + category="structural", + file="(directory)", + line=0, + match=f"{total_size // 1024}KB total", + description=f"skill is {total_size // 1024}KB total (limit: {MAX_TOTAL_SIZE_KB}KB)", + ) + ) return findings @@ -839,23 +1120,23 @@ def _check_structure(skill_dir: Path) -> List[Finding]: def _unicode_char_name(char: str) -> str: """Get a readable name for an invisible unicode character.""" names = { - '\u200b': "zero-width space", - '\u200c': "zero-width non-joiner", - '\u200d': "zero-width joiner", - '\u2060': "word joiner", - '\u2062': "invisible times", - '\u2063': "invisible separator", - '\u2064': "invisible plus", - '\ufeff': "BOM/zero-width no-break space", - '\u202a': "LTR embedding", - '\u202b': "RTL embedding", - '\u202c': "pop directional", - '\u202d': "LTR override", - '\u202e': "RTL override", - '\u2066': "LTR isolate", - '\u2067': "RTL isolate", - '\u2068': "first strong isolate", - '\u2069': "pop directional isolate", + "\u200b": "zero-width space", + "\u200c": "zero-width non-joiner", + "\u200d": "zero-width joiner", + "\u2060": "word joiner", + "\u2062": "invisible times", + "\u2063": "invisible separator", + "\u2064": "invisible plus", + "\ufeff": "BOM/zero-width no-break space", + "\u202a": "LTR embedding", + "\u202b": "RTL embedding", + "\u202c": "pop directional", + "\u202d": "LTR override", + "\u202e": "RTL override", + "\u2066": "LTR isolate", + "\u2067": "RTL isolate", + "\u2068": "first strong isolate", + "\u2069": "pop directional isolate", } return names.get(char, f"U+{ord(char):04X}") @@ -882,8 +1163,7 @@ Respond ONLY with a JSON object (no other text): {{"verdict": "safe"|"caution"|"dangerous", "findings": [{{"description": "...", "severity": "critical"|"high"|"medium"|"low"}}]}}""" -def llm_audit_skill(skill_path: Path, static_result: ScanResult, - model: str = None) -> ScanResult: +def llm_audit_skill(skill_path: Path, static_result: ScanResult, model: str = None) -> ScanResult: """ Run LLM-based security analysis on a skill. Uses the user's configured model. Called after scan_skill() to catch threats the regexes miss. @@ -908,14 +1188,14 @@ def llm_audit_skill(skill_path: Path, static_result: ScanResult, for f in sorted(skill_path.rglob("*")): if f.is_file() and f.suffix.lower() in SCANNABLE_EXTENSIONS: try: - text = f.read_text(encoding='utf-8') + text = f.read_text(encoding="utf-8") rel = str(f.relative_to(skill_path)) content_parts.append(f"--- {rel} ---\n{text}") except (UnicodeDecodeError, OSError): continue elif skill_path.is_file(): try: - content_parts.append(skill_path.read_text(encoding='utf-8')) + content_parts.append(skill_path.read_text(encoding="utf-8")) except (UnicodeDecodeError, OSError): return static_result @@ -936,9 +1216,10 @@ def llm_audit_skill(skill_path: Path, static_result: ScanResult, # Call the LLM via the OpenAI SDK (same pattern as run_agent.py) try: - from openai import OpenAI import os + from openai import OpenAI + api_key = os.getenv("OPENROUTER_API_KEY", "") if not api_key: return static_result @@ -954,10 +1235,12 @@ def llm_audit_skill(skill_path: Path, static_result: ScanResult, ) response = client.chat.completions.create( model=model, - messages=[{ - "role": "user", - "content": LLM_AUDIT_PROMPT.format(skill_content=skill_content), - }], + messages=[ + { + "role": "user", + "content": LLM_AUDIT_PROMPT.format(skill_content=skill_content), + } + ], temperature=0, max_tokens=1000, ) @@ -989,13 +1272,16 @@ def llm_audit_skill(skill_path: Path, static_result: ScanResult, findings=merged_findings, scanned_at=static_result.scanned_at, summary=_build_summary( - static_result.skill_name, static_result.source, - static_result.trust_level, merged_verdict, merged_findings, + static_result.skill_name, + static_result.source, + static_result.trust_level, + merged_verdict, + merged_findings, ), ) -def _parse_llm_response(text: str, skill_name: str) -> List[Finding]: +def _parse_llm_response(text: str, skill_name: str) -> list[Finding]: """Parse the LLM's JSON response into Finding objects.""" import json as json_mod @@ -1022,15 +1308,17 @@ def _parse_llm_response(text: str, skill_name: str) -> List[Finding]: if severity not in ("critical", "high", "medium", "low"): severity = "medium" if desc: - findings.append(Finding( - pattern_id="llm_audit", - severity=severity, - category="llm-detected", - file="(LLM analysis)", - line=0, - match=desc[:120], - description=f"LLM audit: {desc}", - )) + findings.append( + Finding( + pattern_id="llm_audit", + severity=severity, + category="llm-detected", + file="(LLM analysis)", + line=0, + match=desc[:120], + description=f"LLM audit: {desc}", + ) + ) return findings @@ -1039,6 +1327,7 @@ def _get_configured_model() -> str: """Load the user's configured model from ~/.hermes/config.yaml.""" try: from hermes_cli.config import load_config + config = load_config() return config.get("model", "") except Exception: @@ -1049,6 +1338,7 @@ def _get_configured_model() -> str: # Internal helpers # --------------------------------------------------------------------------- + def _resolve_trust_level(source: str) -> str: """Map a source identifier to a trust level.""" # Official optional skills shipped with the repo @@ -1061,7 +1351,7 @@ def _resolve_trust_level(source: str) -> str: return "community" -def _determine_verdict(findings: List[Finding]) -> str: +def _determine_verdict(findings: list[Finding]) -> str: """Determine the overall verdict from a list of findings.""" if not findings: return "safe" @@ -1076,7 +1366,7 @@ def _determine_verdict(findings: List[Finding]) -> str: return "caution" -def _build_summary(name: str, source: str, trust: str, verdict: str, findings: List[Finding]) -> str: +def _build_summary(name: str, source: str, trust: str, verdict: str, findings: list[Finding]) -> str: """Build a one-line summary of the scan result.""" if not findings: return f"{name}: clean scan, no threats detected" diff --git a/tools/skills_hub.py b/tools/skills_hub.py index b4e66746ea..9e0cc1b320 100644 --- a/tools/skills_hub.py +++ b/tools/skills_hub.py @@ -23,15 +23,17 @@ import subprocess import time from abc import ABC, abstractmethod from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import httpx import yaml from tools.skills_guard import ( - ScanResult, scan_skill, should_allow_install, content_hash, TRUSTED_REPOS, + TRUSTED_REPOS, + ScanResult, + content_hash, ) logger = logging.getLogger(__name__) @@ -58,24 +60,27 @@ INDEX_CACHE_TTL = 3600 # 1 hour # Data models # --------------------------------------------------------------------------- + @dataclass class SkillMeta: """Minimal metadata returned by search results.""" + name: str description: str - source: str # "official", "github", "clawhub", "claude-marketplace", "lobehub" - identifier: str # source-specific ID (e.g. "openai/skills/skill-creator") - trust_level: str # "builtin" | "trusted" | "community" - repo: Optional[str] = None - path: Optional[str] = None - tags: List[str] = field(default_factory=list) + source: str # "official", "github", "clawhub", "claude-marketplace", "lobehub" + identifier: str # source-specific ID (e.g. "openai/skills/skill-creator") + trust_level: str # "builtin" | "trusted" | "community" + repo: str | None = None + path: str | None = None + tags: list[str] = field(default_factory=list) @dataclass class SkillBundle: """A downloaded skill ready for quarantine/scanning/installation.""" + name: str - files: Dict[str, str] # relative_path -> text content + files: dict[str, str] # relative_path -> text content source: str identifier: str trust_level: str @@ -85,6 +90,7 @@ class SkillBundle: # GitHub Authentication # --------------------------------------------------------------------------- + class GitHubAuth: """ GitHub API authentication. Tries methods in priority order: @@ -95,11 +101,11 @@ class GitHubAuth: """ def __init__(self): - self._cached_token: Optional[str] = None - self._cached_method: Optional[str] = None + self._cached_token: str | None = None + self._cached_method: str | None = None self._app_token_expiry: float = 0 - def get_headers(self) -> Dict[str, str]: + def get_headers(self) -> dict[str, str]: """Return authorization headers for GitHub API requests.""" token = self._resolve_token() headers = {"Accept": "application/vnd.github.v3+json"} @@ -115,7 +121,7 @@ class GitHubAuth: self._resolve_token() return self._cached_method or "anonymous" - def _resolve_token(self) -> Optional[str]: + def _resolve_token(self) -> str | None: # Return cached token if still valid if self._cached_token: if self._cached_method != "github-app" or time.time() < self._app_token_expiry: @@ -146,12 +152,14 @@ class GitHubAuth: self._cached_method = "anonymous" return None - def _try_gh_cli(self) -> Optional[str]: + def _try_gh_cli(self) -> str | None: """Try to get a token from the gh CLI.""" try: result = subprocess.run( ["gh", "auth", "token"], - capture_output=True, text=True, timeout=5, + capture_output=True, + text=True, + timeout=5, ) if result.returncode == 0 and result.stdout.strip(): return result.stdout.strip() @@ -159,7 +167,7 @@ class GitHubAuth: logger.debug("gh CLI token lookup failed: %s", e) return None - def _try_github_app(self) -> Optional[str]: + def _try_github_app(self) -> str | None: """Try GitHub App JWT authentication if credentials are configured.""" app_id = os.environ.get("GITHUB_APP_ID") key_path = os.environ.get("GITHUB_APP_PRIVATE_KEY_PATH") @@ -208,21 +216,22 @@ class GitHubAuth: # Source adapter interface # --------------------------------------------------------------------------- + class SkillSource(ABC): """Abstract base for all skill registry adapters.""" @abstractmethod - def search(self, query: str, limit: int = 10) -> List[SkillMeta]: + def search(self, query: str, limit: int = 10) -> list[SkillMeta]: """Search for skills matching a query string.""" ... @abstractmethod - def fetch(self, identifier: str) -> Optional[SkillBundle]: + def fetch(self, identifier: str) -> SkillBundle | None: """Download a skill bundle by identifier.""" ... @abstractmethod - def inspect(self, identifier: str) -> Optional[SkillMeta]: + def inspect(self, identifier: str) -> SkillMeta | None: """Fetch metadata for a skill without downloading all files.""" ... @@ -240,6 +249,7 @@ class SkillSource(ABC): # GitHub source adapter # --------------------------------------------------------------------------- + class GitHubSource(SkillSource): """Fetch skills from GitHub repos via the Contents API.""" @@ -249,7 +259,7 @@ class GitHubSource(SkillSource): {"repo": "VoltAgent/awesome-agent-skills", "path": "skills/"}, ] - def __init__(self, auth: GitHubAuth, extra_taps: Optional[List[Dict]] = None): + def __init__(self, auth: GitHubAuth, extra_taps: list[dict] | None = None): self.auth = auth self.taps = list(self.DEFAULT_TAPS) if extra_taps: @@ -267,9 +277,9 @@ class GitHubSource(SkillSource): return "trusted" return "community" - def search(self, query: str, limit: int = 10) -> List[SkillMeta]: + def search(self, query: str, limit: int = 10) -> list[SkillMeta]: """Search all taps for skills matching the query.""" - results: List[SkillMeta] = [] + results: list[SkillMeta] = [] query_lower = query.lower() for tap in self.taps: @@ -287,15 +297,13 @@ class GitHubSource(SkillSource): _trust_rank = {"builtin": 2, "trusted": 1, "community": 0} seen = {} for r in results: - if r.name not in seen: - seen[r.name] = r - elif _trust_rank.get(r.trust_level, 0) > _trust_rank.get(seen[r.name].trust_level, 0): + if r.name not in seen or _trust_rank.get(r.trust_level, 0) > _trust_rank.get(seen[r.name].trust_level, 0): seen[r.name] = r results = list(seen.values()) return results[:limit] - def fetch(self, identifier: str) -> Optional[SkillBundle]: + def fetch(self, identifier: str) -> SkillBundle | None: """ Download a skill from GitHub. identifier format: "owner/repo/path/to/skill-dir" @@ -322,7 +330,7 @@ class GitHubSource(SkillSource): trust_level=trust, ) - def inspect(self, identifier: str) -> Optional[SkillMeta]: + def inspect(self, identifier: str) -> SkillMeta | None: """Fetch just the SKILL.md metadata for preview.""" parts = identifier.split("/", 2) if len(parts) < 3: @@ -363,7 +371,7 @@ class GitHubSource(SkillSource): # -- Internal helpers -- - def _list_skills_in_repo(self, repo: str, path: str) -> List[SkillMeta]: + def _list_skills_in_repo(self, repo: str, path: str) -> list[SkillMeta]: """List skill directories in a GitHub repo path, using cached index.""" cache_key = f"{repo}_{path}".replace("/", "_").replace(" ", "_") cached = self._read_cache(cache_key) @@ -382,7 +390,7 @@ class GitHubSource(SkillSource): if not isinstance(entries, list): return [] - skills: List[SkillMeta] = [] + skills: list[SkillMeta] = [] for entry in entries: if entry.get("type") != "dir": continue @@ -400,7 +408,7 @@ class GitHubSource(SkillSource): self._write_cache(cache_key, [self._meta_to_dict(s) for s in skills]) return skills - def _download_directory(self, repo: str, path: str) -> Dict[str, str]: + def _download_directory(self, repo: str, path: str) -> dict[str, str]: """Recursively download all text files from a GitHub directory.""" url = f"https://api.github.com/repos/{repo}/contents/{path.rstrip('/')}" try: @@ -414,7 +422,7 @@ class GitHubSource(SkillSource): if not isinstance(entries, list): return {} - files: Dict[str, str] = {} + files: dict[str, str] = {} for entry in entries: name = entry.get("name", "") entry_type = entry.get("type", "") @@ -431,7 +439,7 @@ class GitHubSource(SkillSource): return files - def _fetch_file_content(self, repo: str, path: str) -> Optional[str]: + def _fetch_file_content(self, repo: str, path: str) -> str | None: """Fetch a single file's content from GitHub.""" url = f"https://api.github.com/repos/{repo}/contents/{path}" try: @@ -446,7 +454,7 @@ class GitHubSource(SkillSource): logger.debug("GitHub contents API fetch failed: %s", e) return None - def _read_cache(self, key: str) -> Optional[list]: + def _read_cache(self, key: str) -> list | None: """Read cached index if not expired.""" cache_file = INDEX_CACHE_DIR / f"{key}.json" if not cache_file.exists(): @@ -486,10 +494,10 @@ class GitHubSource(SkillSource): """Parse YAML frontmatter from SKILL.md content.""" if not content.startswith("---"): return {} - match = re.search(r'\n---\s*\n', content[3:]) + match = re.search(r"\n---\s*\n", content[3:]) if not match: return {} - yaml_text = content[3:match.start() + 3] + yaml_text = content[3 : match.start() + 3] try: parsed = yaml.safe_load(yaml_text) return parsed if isinstance(parsed, dict) else {} @@ -501,6 +509,7 @@ class GitHubSource(SkillSource): # ClawHub source adapter # --------------------------------------------------------------------------- + class ClawHubSource(SkillSource): """ Fetch skills from ClawHub (clawhub.ai) via their HTTP API. @@ -516,7 +525,7 @@ class ClawHubSource(SkillSource): def trust_level_for(self, identifier: str) -> str: return "community" - def search(self, query: str, limit: int = 10) -> List[SkillMeta]: + def search(self, query: str, limit: int = 10) -> list[SkillMeta]: cache_key = f"clawhub_search_{hashlib.md5(query.encode()).hexdigest()}" cached = _read_index_cache(cache_key) if cached is not None: @@ -548,19 +557,21 @@ class ClawHubSource(SkillSource): tags = item.get("tags", []) if not isinstance(tags, list): tags = [] - results.append(SkillMeta( - name=display_name, - description=summary, - source="clawhub", - identifier=slug, - trust_level="community", - tags=[str(t) for t in tags], - )) + results.append( + SkillMeta( + name=display_name, + description=summary, + source="clawhub", + identifier=slug, + trust_level="community", + tags=[str(t) for t in tags], + ) + ) _write_index_cache(cache_key, [_skill_meta_to_dict(s) for s in results]) return results - def fetch(self, identifier: str) -> Optional[SkillBundle]: + def fetch(self, identifier: str) -> SkillBundle | None: slug = identifier.split("/")[-1] skill_data = self._get_json(f"{self.BASE_URL}/skills/{slug}") @@ -593,7 +604,7 @@ class ClawHubSource(SkillSource): trust_level="community", ) - def inspect(self, identifier: str) -> Optional[SkillMeta]: + def inspect(self, identifier: str) -> SkillMeta | None: slug = identifier.split("/")[-1] data = self._get_json(f"{self.BASE_URL}/skills/{slug}") if not isinstance(data, dict): @@ -612,7 +623,7 @@ class ClawHubSource(SkillSource): tags=[str(t) for t in tags], ) - def _get_json(self, url: str, timeout: int = 20) -> Optional[Any]: + def _get_json(self, url: str, timeout: int = 20) -> Any | None: try: resp = httpx.get(url, timeout=timeout) if resp.status_code != 200: @@ -621,7 +632,7 @@ class ClawHubSource(SkillSource): except (httpx.HTTPError, json.JSONDecodeError): return None - def _resolve_latest_version(self, slug: str, skill_data: Dict[str, Any]) -> Optional[str]: + def _resolve_latest_version(self, slug: str, skill_data: dict[str, Any]) -> str | None: latest = skill_data.get("latestVersion") if isinstance(latest, dict): version = latest.get("version") @@ -643,8 +654,8 @@ class ClawHubSource(SkillSource): return version return None - def _extract_files(self, version_data: Dict[str, Any]) -> Dict[str, str]: - files: Dict[str, str] = {} + def _extract_files(self, version_data: dict[str, Any]) -> dict[str, str]: + files: dict[str, str] = {} file_list = version_data.get("files") if isinstance(file_list, dict): @@ -674,7 +685,7 @@ class ClawHubSource(SkillSource): return files - def _fetch_text(self, url: str) -> Optional[str]: + def _fetch_text(self, url: str) -> str | None: try: resp = httpx.get(url, timeout=20) if resp.status_code == 200: @@ -688,6 +699,7 @@ class ClawHubSource(SkillSource): # Claude Code marketplace source adapter # --------------------------------------------------------------------------- + class ClaudeMarketplaceSource(SkillSource): """ Discover skills from Claude Code marketplace repos. @@ -713,8 +725,8 @@ class ClaudeMarketplaceSource(SkillSource): return "trusted" return "community" - def search(self, query: str, limit: int = 10) -> List[SkillMeta]: - results: List[SkillMeta] = [] + def search(self, query: str, limit: int = 10) -> list[SkillMeta]: + results: list[SkillMeta] = [] query_lower = query.lower() for marketplace_repo in self.KNOWN_MARKETPLACES: @@ -730,18 +742,20 @@ class ClaudeMarketplaceSource(SkillSource): else: identifier = f"{marketplace_repo}/{source_path}" - results.append(SkillMeta( - name=plugin.get("name", ""), - description=plugin.get("description", ""), - source="claude-marketplace", - identifier=identifier, - trust_level=self.trust_level_for(identifier), - repo=marketplace_repo, - )) + results.append( + SkillMeta( + name=plugin.get("name", ""), + description=plugin.get("description", ""), + source="claude-marketplace", + identifier=identifier, + trust_level=self.trust_level_for(identifier), + repo=marketplace_repo, + ) + ) return results[:limit] - def fetch(self, identifier: str) -> Optional[SkillBundle]: + def fetch(self, identifier: str) -> SkillBundle | None: # Delegate to GitHub Contents API since marketplace skills live in GitHub repos gh = GitHubSource(auth=self.auth) bundle = gh.fetch(identifier) @@ -749,7 +763,7 @@ class ClaudeMarketplaceSource(SkillSource): bundle.source = "claude-marketplace" return bundle - def inspect(self, identifier: str) -> Optional[SkillMeta]: + def inspect(self, identifier: str) -> SkillMeta | None: gh = GitHubSource(auth=self.auth) meta = gh.inspect(identifier) if meta: @@ -757,7 +771,7 @@ class ClaudeMarketplaceSource(SkillSource): meta.trust_level = self.trust_level_for(identifier) return meta - def _fetch_marketplace_index(self, repo: str) -> List[dict]: + def _fetch_marketplace_index(self, repo: str) -> list[dict]: """Fetch and parse .claude-plugin/marketplace.json from a repo.""" cache_key = f"claude_marketplace_{repo.replace('/', '_')}" cached = _read_index_cache(cache_key) @@ -786,6 +800,7 @@ class ClaudeMarketplaceSource(SkillSource): # LobeHub source adapter # --------------------------------------------------------------------------- + class LobeHubSource(SkillSource): """ Fetch skills from LobeHub's agent marketplace (14,500+ agents). @@ -802,13 +817,13 @@ class LobeHubSource(SkillSource): def trust_level_for(self, identifier: str) -> str: return "community" - def search(self, query: str, limit: int = 10) -> List[SkillMeta]: + def search(self, query: str, limit: int = 10) -> list[SkillMeta]: index = self._fetch_index() if not index: return [] query_lower = query.lower() - results: List[SkillMeta] = [] + results: list[SkillMeta] = [] agents = index.get("agents", index) if isinstance(index, dict) else index if not isinstance(agents, list): @@ -823,21 +838,23 @@ class LobeHubSource(SkillSource): searchable = f"{title} {desc} {' '.join(tags) if isinstance(tags, list) else ''}".lower() if query_lower in searchable: identifier = agent.get("identifier", title.lower().replace(" ", "-")) - results.append(SkillMeta( - name=identifier, - description=desc[:200], - source="lobehub", - identifier=f"lobehub/{identifier}", - trust_level="community", - tags=tags if isinstance(tags, list) else [], - )) + results.append( + SkillMeta( + name=identifier, + description=desc[:200], + source="lobehub", + identifier=f"lobehub/{identifier}", + trust_level="community", + tags=tags if isinstance(tags, list) else [], + ) + ) if len(results) >= limit: break return results - def fetch(self, identifier: str) -> Optional[SkillBundle]: + def fetch(self, identifier: str) -> SkillBundle | None: # Strip "lobehub/" prefix if present agent_id = identifier.split("/", 1)[-1] if identifier.startswith("lobehub/") else identifier @@ -854,7 +871,7 @@ class LobeHubSource(SkillSource): trust_level="community", ) - def inspect(self, identifier: str) -> Optional[SkillMeta]: + def inspect(self, identifier: str) -> SkillMeta | None: agent_id = identifier.split("/", 1)[-1] if identifier.startswith("lobehub/") else identifier index = self._fetch_index() if not index: @@ -877,7 +894,7 @@ class LobeHubSource(SkillSource): ) return None - def _fetch_index(self) -> Optional[Any]: + def _fetch_index(self) -> Any | None: """Fetch the LobeHub agent index (cached for 1 hour).""" cache_key = "lobehub_index" cached = _read_index_cache(cache_key) @@ -895,7 +912,7 @@ class LobeHubSource(SkillSource): _write_index_cache(cache_key, data) return data - def _fetch_agent(self, agent_id: str) -> Optional[dict]: + def _fetch_agent(self, agent_id: str) -> dict | None: """Fetch a single agent's JSON file.""" url = f"https://chat-agents.lobehub.com/{agent_id}.json" try: @@ -924,8 +941,8 @@ class LobeHubSource(SkillSource): "metadata:", " hermes:", f" tags: [{', '.join(str(t) for t in tag_list)}]", - f" lobehub:", - f" source: lobehub", + " lobehub:", + " source: lobehub", "---", ] @@ -946,6 +963,7 @@ class LobeHubSource(SkillSource): # Official optional skills source adapter # --------------------------------------------------------------------------- + class OptionalSkillSource(SkillSource): """ Fetch skills from the optional-skills/ directory shipped with the repo. @@ -967,8 +985,8 @@ class OptionalSkillSource(SkillSource): # -- search ----------------------------------------------------------- - def search(self, query: str, limit: int = 10) -> List[SkillMeta]: - results: List[SkillMeta] = [] + def search(self, query: str, limit: int = 10) -> list[SkillMeta]: + results: list[SkillMeta] = [] query_lower = query.lower() for meta in self._scan_all(): @@ -982,7 +1000,7 @@ class OptionalSkillSource(SkillSource): # -- fetch ------------------------------------------------------------ - def fetch(self, identifier: str) -> Optional[SkillBundle]: + def fetch(self, identifier: str) -> SkillBundle | None: # identifier format: "official/category/skill" or "official/skill" rel = identifier.split("/", 1)[-1] if identifier.startswith("official/") else identifier skill_dir = self._optional_dir / rel @@ -1004,7 +1022,7 @@ class OptionalSkillSource(SkillSource): else: skill_dir = resolved - files: Dict[str, str] = {} + files: dict[str, str] = {} for f in skill_dir.rglob("*"): if f.is_file() and not f.name.startswith("."): rel_path = str(f.relative_to(skill_dir)) @@ -1029,7 +1047,7 @@ class OptionalSkillSource(SkillSource): # -- inspect ---------------------------------------------------------- - def inspect(self, identifier: str) -> Optional[SkillMeta]: + def inspect(self, identifier: str) -> SkillMeta | None: rel = identifier.split("/", 1)[-1] if identifier.startswith("official/") else identifier skill_name = rel.rsplit("/", 1)[-1] @@ -1040,7 +1058,7 @@ class OptionalSkillSource(SkillSource): # -- internal helpers ------------------------------------------------- - def _find_skill_dir(self, name: str) -> Optional[Path]: + def _find_skill_dir(self, name: str) -> Path | None: """Find a skill directory by name anywhere in optional-skills/.""" if not self._optional_dir.is_dir(): return None @@ -1049,12 +1067,12 @@ class OptionalSkillSource(SkillSource): return skill_md.parent return None - def _scan_all(self) -> List[SkillMeta]: + def _scan_all(self) -> list[SkillMeta]: """Enumerate all optional skills with metadata.""" if not self._optional_dir.is_dir(): return [] - results: List[SkillMeta] = [] + results: list[SkillMeta] = [] for skill_md in sorted(self._optional_dir.rglob("SKILL.md")): parent = skill_md.parent rel_parts = parent.relative_to(self._optional_dir).parts @@ -1078,15 +1096,17 @@ class OptionalSkillSource(SkillSource): rel_path = str(parent.relative_to(self._optional_dir)) - results.append(SkillMeta( - name=name, - description=desc[:200], - source="official", - identifier=f"official/{rel_path}", - trust_level="builtin", - path=rel_path, - tags=tags if isinstance(tags, list) else [], - )) + results.append( + SkillMeta( + name=name, + description=desc[:200], + source="official", + identifier=f"official/{rel_path}", + trust_level="builtin", + path=rel_path, + tags=tags if isinstance(tags, list) else [], + ) + ) return results @@ -1095,10 +1115,10 @@ class OptionalSkillSource(SkillSource): """Parse YAML frontmatter from SKILL.md content.""" if not content.startswith("---"): return {} - match = re.search(r'\n---\s*\n', content[3:]) + match = re.search(r"\n---\s*\n", content[3:]) if not match: return {} - yaml_text = content[3:match.start() + 3] + yaml_text = content[3 : match.start() + 3] try: parsed = yaml.safe_load(yaml_text) return parsed if isinstance(parsed, dict) else {} @@ -1110,7 +1130,8 @@ class OptionalSkillSource(SkillSource): # Shared cache helpers (used by multiple adapters) # --------------------------------------------------------------------------- -def _read_index_cache(key: str) -> Optional[Any]: + +def _read_index_cache(key: str) -> Any | None: """Read cached data if not expired.""" cache_file = INDEX_CACHE_DIR / f"{key}.json" if not cache_file.exists(): @@ -1152,6 +1173,7 @@ def _skill_meta_to_dict(meta: SkillMeta) -> dict: # Lock file management # --------------------------------------------------------------------------- + class HubLockFile: """Manages skills/.hub/lock.json — tracks provenance of installed hub skills.""" @@ -1179,7 +1201,7 @@ class HubLockFile: scan_verdict: str, skill_hash: str, install_path: str, - files: List[str], + files: list[str], ) -> None: data = self.load() data["installed"][name] = { @@ -1190,8 +1212,8 @@ class HubLockFile: "content_hash": skill_hash, "install_path": install_path, "files": files, - "installed_at": datetime.now(timezone.utc).isoformat(), - "updated_at": datetime.now(timezone.utc).isoformat(), + "installed_at": datetime.now(UTC).isoformat(), + "updated_at": datetime.now(UTC).isoformat(), } self.save(data) @@ -1200,11 +1222,11 @@ class HubLockFile: data["installed"].pop(name, None) self.save(data) - def get_installed(self, name: str) -> Optional[dict]: + def get_installed(self, name: str) -> dict | None: data = self.load() return data["installed"].get(name) - def list_installed(self) -> List[dict]: + def list_installed(self) -> list[dict]: data = self.load() result = [] for name, entry in data["installed"].items(): @@ -1220,13 +1242,14 @@ class HubLockFile: # Taps management # --------------------------------------------------------------------------- + class TapsManager: """Manages the taps.json file — custom GitHub repo sources.""" def __init__(self, path: Path = TAPS_FILE): self.path = path - def load(self) -> List[dict]: + def load(self) -> list[dict]: if not self.path.exists(): return [] try: @@ -1235,7 +1258,7 @@ class TapsManager: except (json.JSONDecodeError, OSError): return [] - def save(self, taps: List[dict]) -> None: + def save(self, taps: list[dict]) -> None: self.path.parent.mkdir(parents=True, exist_ok=True) self.path.write_text(json.dumps({"taps": taps}, indent=2) + "\n") @@ -1257,7 +1280,7 @@ class TapsManager: self.save(new_taps) return True - def list_taps(self) -> List[dict]: + def list_taps(self) -> list[dict]: return self.load() @@ -1265,11 +1288,13 @@ class TapsManager: # Audit log # --------------------------------------------------------------------------- -def append_audit_log(action: str, skill_name: str, source: str, - trust_level: str, verdict: str, extra: str = "") -> None: + +def append_audit_log( + action: str, skill_name: str, source: str, trust_level: str, verdict: str, extra: str = "" +) -> None: """Append a line to the audit log.""" AUDIT_LOG.parent.mkdir(parents=True, exist_ok=True) - timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + timestamp = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ") parts = [timestamp, action, skill_name, f"{source}:{trust_level}", verdict] if extra: parts.append(extra) @@ -1285,6 +1310,7 @@ def append_audit_log(action: str, skill_name: str, source: str, # Hub operations (high-level) # --------------------------------------------------------------------------- + def ensure_hub_dirs() -> None: """Create the .hub directory structure if it doesn't exist.""" HUB_DIR.mkdir(parents=True, exist_ok=True) @@ -1347,15 +1373,18 @@ def install_from_quarantine( ) append_audit_log( - "INSTALL", skill_name, bundle.source, - bundle.trust_level, scan_result.verdict, + "INSTALL", + skill_name, + bundle.source, + bundle.trust_level, + scan_result.verdict, content_hash(install_dir), ) return install_dir -def uninstall_skill(skill_name: str) -> Tuple[bool, str]: +def uninstall_skill(skill_name: str) -> tuple[bool, str]: """Remove a hub-installed skill. Refuses to remove builtins.""" lock = HubLockFile() entry = lock.get_installed(skill_name) @@ -1372,7 +1401,7 @@ def uninstall_skill(skill_name: str) -> Tuple[bool, str]: return True, f"Uninstalled '{skill_name}' from {entry['install_path']}" -def create_source_router(auth: Optional[GitHubAuth] = None) -> List[SkillSource]: +def create_source_router(auth: GitHubAuth | None = None) -> list[SkillSource]: """ Create all configured source adapters. Returns a list of active sources for search/fetch operations. @@ -1383,8 +1412,8 @@ def create_source_router(auth: Optional[GitHubAuth] = None) -> List[SkillSource] taps_mgr = TapsManager() extra_taps = taps_mgr.list_taps() - sources: List[SkillSource] = [ - OptionalSkillSource(), # Official optional skills (highest priority) + sources: list[SkillSource] = [ + OptionalSkillSource(), # Official optional skills (highest priority) GitHubSource(auth=auth, extra_taps=extra_taps), ClawHubSource(), ClaudeMarketplaceSource(auth=auth), @@ -1394,10 +1423,11 @@ def create_source_router(auth: Optional[GitHubAuth] = None) -> List[SkillSource] return sources -def unified_search(query: str, sources: List[SkillSource], - source_filter: str = "all", limit: int = 10) -> List[SkillMeta]: +def unified_search( + query: str, sources: list[SkillSource], source_filter: str = "all", limit: int = 10 +) -> list[SkillMeta]: """Search all sources and merge results.""" - all_results: List[SkillMeta] = [] + all_results: list[SkillMeta] = [] for src in sources: if source_filter != "all" and src.source_id() != source_filter: @@ -1410,11 +1440,9 @@ def unified_search(query: str, sources: List[SkillSource], # Deduplicate by name, preferring higher trust levels _TRUST_RANK = {"builtin": 2, "trusted": 1, "community": 0} - seen: Dict[str, SkillMeta] = {} + seen: dict[str, SkillMeta] = {} for r in all_results: - if r.name not in seen: - seen[r.name] = r - elif _TRUST_RANK.get(r.trust_level, 0) > _TRUST_RANK.get(seen[r.name].trust_level, 0): + if r.name not in seen or _TRUST_RANK.get(r.trust_level, 0) > _TRUST_RANK.get(seen[r.name].trust_level, 0): seen[r.name] = r deduped = list(seen.values()) diff --git a/tools/skills_sync.py b/tools/skills_sync.py index b89e45998f..06fbd5df67 100644 --- a/tools/skills_sync.py +++ b/tools/skills_sync.py @@ -26,7 +26,6 @@ import logging import os import shutil from pathlib import Path -from typing import Dict, List, Tuple logger = logging.getLogger(__name__) @@ -41,7 +40,7 @@ def _get_bundled_dir() -> Path: return Path(__file__).parent.parent / "skills" -def _read_manifest() -> Dict[str, str]: +def _read_manifest() -> dict[str, str]: """ Read the manifest as a dict of {skill_name: origin_hash}. @@ -64,11 +63,11 @@ def _read_manifest() -> Dict[str, str]: # v1 format: plain name — empty hash triggers migration result[line] = "" return result - except (OSError, IOError): + except OSError: return {} -def _write_manifest(entries: Dict[str, str]): +def _write_manifest(entries: dict[str, str]): """Write the manifest file atomically in v2 format (name:hash). Uses a temp file + os.replace() to avoid corruption if the process @@ -101,7 +100,7 @@ def _write_manifest(entries: Dict[str, str]): logger.debug("Failed to write skills manifest %s: %s", MANIFEST_FILE, e, exc_info=True) -def _discover_bundled_skills(bundled_dir: Path) -> List[Tuple[str, Path]]: +def _discover_bundled_skills(bundled_dir: Path) -> list[tuple[str, Path]]: """ Find all SKILL.md files in the bundled directory. Returns list of (skill_name, skill_directory_path) tuples. @@ -139,7 +138,7 @@ def _dir_hash(directory: Path) -> str: rel = fpath.relative_to(directory) hasher.update(str(rel).encode("utf-8")) hasher.update(fpath.read_bytes()) - except (OSError, IOError): + except OSError: pass return hasher.hexdigest() @@ -155,8 +154,12 @@ def sync_skills(quiet: bool = False) -> dict: bundled_dir = _get_bundled_dir() if not bundled_dir.exists(): return { - "copied": [], "updated": [], "skipped": 0, - "user_modified": [], "cleaned": [], "total_bundled": 0, + "copied": [], + "updated": [], + "skipped": 0, + "user_modified": [], + "cleaned": [], + "total_bundled": 0, } SKILLS_DIR.mkdir(parents=True, exist_ok=True) @@ -187,7 +190,7 @@ def sync_skills(quiet: bool = False) -> dict: manifest[skill_name] = bundled_hash if not quiet: print(f" + {skill_name}") - except (OSError, IOError) as e: + except OSError as e: if not quiet: print(f" ! Failed to copy {skill_name}: {e}") # Do NOT add to manifest — next sync should retry @@ -229,12 +232,12 @@ def sync_skills(quiet: bool = False) -> dict: print(f" ↑ {skill_name} (updated)") # Remove backup after successful copy shutil.rmtree(backup, ignore_errors=True) - except (OSError, IOError): + except OSError: # Restore from backup if backup.exists() and not dest.exists(): shutil.move(str(backup), str(dest)) raise - except (OSError, IOError) as e: + except OSError as e: if not quiet: print(f" ! Failed to update {skill_name}: {e}") else: @@ -257,7 +260,7 @@ def sync_skills(quiet: bool = False) -> dict: try: dest_desc.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(desc_md, dest_desc) - except (OSError, IOError) as e: + except OSError as e: logger.debug("Could not copy %s: %s", desc_md, e) _write_manifest(manifest) diff --git a/tools/skills_tool.py b/tools/skills_tool.py index e8baa0f595..9d5366c1b1 100644 --- a/tools/skills_tool.py +++ b/tools/skills_tool.py @@ -40,9 +40,9 @@ SKILL.md Format (YAML Frontmatter, agentskills.io compatible): tags: [fine-tuning, llm] related_skills: [peft, lora] --- - + # Skill Title - + Full instructions and content here... Available tools: @@ -51,13 +51,13 @@ Available tools: Usage: from tools.skills_tool import skills_list, skill_view, check_skills_requirements - + # List all skills (returns metadata only - token efficient) result = skills_list() - + # View a skill's main content (loads full instructions) content = skill_view("axolotl") - + # View a reference file within a skill (loads linked file) content = skill_view("axolotl", "references/dataset-formats.md") """ @@ -67,11 +67,10 @@ import os import re import sys from pathlib import Path -from typing import Dict, Any, List, Optional, Tuple +from typing import Any import yaml - # All skills live in ~/.hermes/skills/ (seeded from bundled skills/ on install). # This is the single source of truth -- agent edits, hub installs, and bundled # skills all coexist here without polluting the git repo. @@ -91,7 +90,7 @@ _PLATFORM_MAP = { } -def skill_matches_platform(frontmatter: Dict[str, Any]) -> bool: +def skill_matches_platform(frontmatter: dict[str, Any]) -> bool: """Check if a skill is compatible with the current OS platform. Skills declare platform requirements via a top-level ``platforms`` list @@ -123,28 +122,28 @@ def check_skills_requirements() -> bool: return True -def _parse_frontmatter(content: str) -> Tuple[Dict[str, Any], str]: +def _parse_frontmatter(content: str) -> tuple[dict[str, Any], str]: """ Parse YAML frontmatter from markdown content. - + Uses yaml.safe_load for full YAML support (nested metadata, lists, etc.) with a fallback to simple key:value splitting for robustness. - + Args: content: Full markdown file content - + Returns: Tuple of (frontmatter dict, remaining content) """ frontmatter = {} body = content - + if content.startswith("---"): - end_match = re.search(r'\n---\s*\n', content[3:]) + end_match = re.search(r"\n---\s*\n", content[3:]) if end_match: - yaml_content = content[3:end_match.start() + 3] - body = content[end_match.end() + 3:] - + yaml_content = content[3 : end_match.start() + 3] + body = content[end_match.end() + 3 :] + try: parsed = yaml.safe_load(yaml_content) if isinstance(parsed, dict): @@ -152,18 +151,18 @@ def _parse_frontmatter(content: str) -> Tuple[Dict[str, Any], str]: # yaml.safe_load returns None for empty frontmatter except yaml.YAMLError: # Fallback: simple key:value parsing for malformed YAML - for line in yaml_content.strip().split('\n'): - if ':' in line: - key, value = line.split(':', 1) + for line in yaml_content.strip().split("\n"): + if ":" in line: + key, value = line.split(":", 1) frontmatter[key.strip()] = value.strip() - + return frontmatter, body -def _get_category_from_path(skill_path: Path) -> Optional[str]: +def _get_category_from_path(skill_path: Path) -> str | None: """ Extract category from skill path based on directory structure. - + For paths like: ~/.hermes/skills/mlops/axolotl/SKILL.md -> "mlops" """ try: @@ -179,134 +178,136 @@ def _get_category_from_path(skill_path: Path) -> Optional[str]: def _estimate_tokens(content: str) -> int: """ Rough token estimate (4 chars per token average). - + Args: content: Text content - + Returns: Estimated token count """ return len(content) // 4 -def _parse_tags(tags_value) -> List[str]: +def _parse_tags(tags_value) -> list[str]: """ Parse tags from frontmatter value. - + Handles: - Already-parsed list (from yaml.safe_load): [tag1, tag2] - String with brackets: "[tag1, tag2]" - Comma-separated string: "tag1, tag2" - + Args: tags_value: Raw tags value — may be a list or string - + Returns: List of tag strings """ if not tags_value: return [] - + # yaml.safe_load already returns a list for [tag1, tag2] if isinstance(tags_value, list): return [str(t).strip() for t in tags_value if t] - + # String fallback — handle bracket-wrapped or comma-separated tags_value = str(tags_value).strip() - if tags_value.startswith('[') and tags_value.endswith(']'): + if tags_value.startswith("[") and tags_value.endswith("]"): tags_value = tags_value[1:-1] - - return [t.strip().strip('"\'') for t in tags_value.split(',') if t.strip()] + + return [t.strip().strip("\"'") for t in tags_value.split(",") if t.strip()] -def _find_all_skills() -> List[Dict[str, Any]]: +def _find_all_skills() -> list[dict[str, Any]]: """ Recursively find all skills in ~/.hermes/skills/. - + Returns metadata for progressive disclosure (tier 1): - name, description, category - + Returns: List of skill metadata dicts """ skills = [] - + if not SKILLS_DIR.exists(): return skills - + for skill_md in SKILLS_DIR.rglob("SKILL.md"): - if any(part in ('.git', '.github', '.hub') for part in skill_md.parts): + if any(part in (".git", ".github", ".hub") for part in skill_md.parts): continue - + skill_dir = skill_md.parent - + try: - content = skill_md.read_text(encoding='utf-8') + content = skill_md.read_text(encoding="utf-8") frontmatter, body = _parse_frontmatter(content) # Skip skills incompatible with the current OS platform if not skill_matches_platform(frontmatter): continue - - name = frontmatter.get('name', skill_dir.name)[:MAX_NAME_LENGTH] - - description = frontmatter.get('description', '') + + name = frontmatter.get("name", skill_dir.name)[:MAX_NAME_LENGTH] + + description = frontmatter.get("description", "") if not description: - for line in body.strip().split('\n'): + for line in body.strip().split("\n"): line = line.strip() - if line and not line.startswith('#'): + if line and not line.startswith("#"): description = line break - + if len(description) > MAX_DESCRIPTION_LENGTH: - description = description[:MAX_DESCRIPTION_LENGTH - 3] + "..." - + description = description[: MAX_DESCRIPTION_LENGTH - 3] + "..." + category = _get_category_from_path(skill_md) - - skills.append({ - "name": name, - "description": description, - "category": category, - }) - + + skills.append( + { + "name": name, + "description": description, + "category": category, + } + ) + except Exception: continue - + return skills -def _load_category_description(category_dir: Path) -> Optional[str]: +def _load_category_description(category_dir: Path) -> str | None: """ Load category description from DESCRIPTION.md if it exists. - + Args: category_dir: Path to the category directory - + Returns: Description string or None if not found """ desc_file = category_dir / "DESCRIPTION.md" if not desc_file.exists(): return None - + try: - content = desc_file.read_text(encoding='utf-8') + content = desc_file.read_text(encoding="utf-8") # Parse frontmatter if present frontmatter, body = _parse_frontmatter(content) - + # Prefer frontmatter description, fall back to first non-header line - description = frontmatter.get('description', '') + description = frontmatter.get("description", "") if not description: - for line in body.strip().split('\n'): + for line in body.strip().split("\n"): line = line.strip() - if line and not line.startswith('#'): + if line and not line.startswith("#"): description = line break - + # Truncate to reasonable length if len(description) > MAX_DESCRIPTION_LENGTH: - description = description[:MAX_DESCRIPTION_LENGTH - 3] + "..." - + description = description[: MAX_DESCRIPTION_LENGTH - 3] + "..." + return description if description else None except Exception: return None @@ -315,26 +316,24 @@ def _load_category_description(category_dir: Path) -> Optional[str]: def skills_categories(verbose: bool = False, task_id: str = None) -> str: """ List available skill categories with descriptions (progressive disclosure tier 0). - + Returns category names and descriptions for efficient discovery before drilling down. Categories can have a DESCRIPTION.md file with a description frontmatter field or first paragraph to explain what skills are in that category. - + Args: verbose: If True, include skill counts per category (default: False, but currently always included) task_id: Optional task identifier (unused, for API consistency) - + Returns: JSON string with list of categories and their descriptions """ try: if not SKILLS_DIR.exists(): - return json.dumps({ - "success": True, - "categories": [], - "message": "No skills directory found." - }, ensure_ascii=False) - + return json.dumps( + {"success": True, "categories": [], "message": "No skills directory found."}, ensure_ascii=False + ) + category_dirs = {} for skill_md in SKILLS_DIR.rglob("SKILL.md"): category = _get_category_from_path(skill_md) @@ -342,121 +341,125 @@ def skills_categories(verbose: bool = False, task_id: str = None) -> str: category_dir = SKILLS_DIR / category if category not in category_dirs: category_dirs[category] = category_dir - + categories = [] for name in sorted(category_dirs.keys()): category_dir = category_dirs[name] description = _load_category_description(category_dir) skill_count = sum(1 for _ in category_dir.rglob("SKILL.md")) - + cat_entry = {"name": name, "skill_count": skill_count} if description: cat_entry["description"] = description categories.append(cat_entry) - - return json.dumps({ - "success": True, - "categories": categories, - "hint": "If a category is relevant to your task, use skills_list with that category to see available skills" - }, ensure_ascii=False) - + + return json.dumps( + { + "success": True, + "categories": categories, + "hint": "If a category is relevant to your task, use skills_list with that category to see available skills", + }, + ensure_ascii=False, + ) + except Exception as e: - return json.dumps({ - "success": False, - "error": str(e) - }, ensure_ascii=False) + return json.dumps({"success": False, "error": str(e)}, ensure_ascii=False) def skills_list(category: str = None, task_id: str = None) -> str: """ List all available skills (progressive disclosure tier 1 - minimal metadata). - - Returns only name + description to minimize token usage. Use skill_view() to + + Returns only name + description to minimize token usage. Use skill_view() to load full content, tags, related files, etc. - + Args: category: Optional category filter (e.g., "mlops") task_id: Optional task identifier (unused, for API consistency) - + Returns: JSON string with minimal skill info: name, description, category """ try: if not SKILLS_DIR.exists(): SKILLS_DIR.mkdir(parents=True, exist_ok=True) - return json.dumps({ - "success": True, - "skills": [], - "categories": [], - "message": "No skills found. Skills directory created at ~/.hermes/skills/" - }, ensure_ascii=False) - + return json.dumps( + { + "success": True, + "skills": [], + "categories": [], + "message": "No skills found. Skills directory created at ~/.hermes/skills/", + }, + ensure_ascii=False, + ) + # Find all skills all_skills = _find_all_skills() - + if not all_skills: - return json.dumps({ - "success": True, - "skills": [], - "categories": [], - "message": "No skills found in skills/ directory." - }, ensure_ascii=False) - + return json.dumps( + {"success": True, "skills": [], "categories": [], "message": "No skills found in skills/ directory."}, + ensure_ascii=False, + ) + # Filter by category if specified if category: all_skills = [s for s in all_skills if s.get("category") == category] - + # Sort by category then name all_skills.sort(key=lambda s: (s.get("category") or "", s["name"])) - + # Extract unique categories categories = sorted(set(s.get("category") for s in all_skills if s.get("category"))) - - return json.dumps({ - "success": True, - "skills": all_skills, - "categories": categories, - "count": len(all_skills), - "hint": "Use skill_view(name) to see full content, tags, and linked files" - }, ensure_ascii=False) - + + return json.dumps( + { + "success": True, + "skills": all_skills, + "categories": categories, + "count": len(all_skills), + "hint": "Use skill_view(name) to see full content, tags, and linked files", + }, + ensure_ascii=False, + ) + except Exception as e: - return json.dumps({ - "success": False, - "error": str(e) - }, ensure_ascii=False) + return json.dumps({"success": False, "error": str(e)}, ensure_ascii=False) def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: """ View the content of a skill or a specific file within a skill directory. - + Args: name: Name or path of the skill (e.g., "axolotl" or "03-fine-tuning/axolotl") file_path: Optional path to a specific file within the skill (e.g., "references/api.md") task_id: Optional task identifier (unused, for API consistency) - + Returns: JSON string with skill content or error message """ try: if not SKILLS_DIR.exists(): - return json.dumps({ - "success": False, - "error": "Skills directory does not exist yet. It will be created on first install." - }, ensure_ascii=False) - + return json.dumps( + { + "success": False, + "error": "Skills directory does not exist yet. It will be created on first install.", + }, + ensure_ascii=False, + ) + skill_dir = None skill_md = None - + # Try direct path first (e.g., "mlops/axolotl") direct_path = SKILLS_DIR / name if direct_path.is_dir() and (direct_path / "SKILL.md").exists(): skill_dir = direct_path skill_md = direct_path / "SKILL.md" - elif direct_path.with_suffix('.md').exists(): - skill_md = direct_path.with_suffix('.md') - + elif direct_path.with_suffix(".md").exists(): + skill_md = direct_path.with_suffix(".md") + # Search by directory name if not skill_md: for found_skill_md in SKILLS_DIR.rglob("SKILL.md"): @@ -464,64 +467,70 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: skill_dir = found_skill_md.parent skill_md = found_skill_md break - + # Legacy: flat .md files if not skill_md: for found_md in SKILLS_DIR.rglob(f"{name}.md"): if found_md.name != "SKILL.md": skill_md = found_md break - + if not skill_md or not skill_md.exists(): # List available skills in error message all_skills = _find_all_skills() available = [s["name"] for s in all_skills[:20]] # Limit to 20 - return json.dumps({ - "success": False, - "error": f"Skill '{name}' not found.", - "available_skills": available, - "hint": "Use skills_list to see all available skills" - }, ensure_ascii=False) - + return json.dumps( + { + "success": False, + "error": f"Skill '{name}' not found.", + "available_skills": available, + "hint": "Use skills_list to see all available skills", + }, + ensure_ascii=False, + ) + # If a specific file path is requested, read that instead if file_path and skill_dir: # Security: Prevent path traversal attacks normalized_path = Path(file_path) if ".." in normalized_path.parts: - return json.dumps({ - "success": False, - "error": "Path traversal ('..') is not allowed.", - "hint": "Use a relative path within the skill directory" - }, ensure_ascii=False) - + return json.dumps( + { + "success": False, + "error": "Path traversal ('..') is not allowed.", + "hint": "Use a relative path within the skill directory", + }, + ensure_ascii=False, + ) + target_file = skill_dir / file_path - + # Security: Verify resolved path is still within skill directory try: resolved = target_file.resolve() skill_dir_resolved = skill_dir.resolve() if not resolved.is_relative_to(skill_dir_resolved): - return json.dumps({ - "success": False, - "error": "Path escapes skill directory boundary.", - "hint": "Use a relative path within the skill directory" - }, ensure_ascii=False) + return json.dumps( + { + "success": False, + "error": "Path escapes skill directory boundary.", + "hint": "Use a relative path within the skill directory", + }, + ensure_ascii=False, + ) except (OSError, ValueError): - return json.dumps({ - "success": False, - "error": f"Invalid file path: '{file_path}'", - "hint": "Use a valid relative path within the skill directory" - }, ensure_ascii=False) + return json.dumps( + { + "success": False, + "error": f"Invalid file path: '{file_path}'", + "hint": "Use a valid relative path within the skill directory", + }, + ensure_ascii=False, + ) if not target_file.exists(): # List available files in the skill directory, organized by type - available_files = { - "references": [], - "templates": [], - "assets": [], - "scripts": [], - "other": [] - } - + available_files = {"references": [], "templates": [], "assets": [], "scripts": [], "other": []} + # Scan for all readable files for f in skill_dir.rglob("*"): if f.is_file() and f.name != "SKILL.md": @@ -534,82 +543,85 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: available_files["assets"].append(rel) elif rel.startswith("scripts/"): available_files["scripts"].append(rel) - elif f.suffix in ['.md', '.py', '.yaml', '.yml', '.json', '.tex', '.sh']: + elif f.suffix in [".md", ".py", ".yaml", ".yml", ".json", ".tex", ".sh"]: available_files["other"].append(rel) - + # Remove empty categories available_files = {k: v for k, v in available_files.items() if v} - - return json.dumps({ - "success": False, - "error": f"File '{file_path}' not found in skill '{name}'.", - "available_files": available_files, - "hint": "Use one of the available file paths listed above" - }, ensure_ascii=False) - + + return json.dumps( + { + "success": False, + "error": f"File '{file_path}' not found in skill '{name}'.", + "available_files": available_files, + "hint": "Use one of the available file paths listed above", + }, + ensure_ascii=False, + ) + # Read the file content try: - content = target_file.read_text(encoding='utf-8') + content = target_file.read_text(encoding="utf-8") except UnicodeDecodeError: # Binary file - return info about it instead - return json.dumps({ - "success": True, - "name": name, - "file": file_path, - "content": f"[Binary file: {target_file.name}, size: {target_file.stat().st_size} bytes]", - "is_binary": True - }, ensure_ascii=False) - - return json.dumps({ - "success": True, - "name": name, - "file": file_path, - "content": content, - "file_type": target_file.suffix - }, ensure_ascii=False) - + return json.dumps( + { + "success": True, + "name": name, + "file": file_path, + "content": f"[Binary file: {target_file.name}, size: {target_file.stat().st_size} bytes]", + "is_binary": True, + }, + ensure_ascii=False, + ) + + return json.dumps( + {"success": True, "name": name, "file": file_path, "content": content, "file_type": target_file.suffix}, + ensure_ascii=False, + ) + # Read the main skill content - content = skill_md.read_text(encoding='utf-8') + content = skill_md.read_text(encoding="utf-8") frontmatter, body = _parse_frontmatter(content) - + # Get reference, template, asset, and script files if this is a directory-based skill reference_files = [] template_files = [] asset_files = [] script_files = [] - + if skill_dir: references_dir = skill_dir / "references" if references_dir.exists(): reference_files = [str(f.relative_to(skill_dir)) for f in references_dir.glob("*.md")] - + templates_dir = skill_dir / "templates" if templates_dir.exists(): - for ext in ['*.md', '*.py', '*.yaml', '*.yml', '*.json', '*.tex', '*.sh']: + for ext in ["*.md", "*.py", "*.yaml", "*.yml", "*.json", "*.tex", "*.sh"]: template_files.extend([str(f.relative_to(skill_dir)) for f in templates_dir.rglob(ext)]) - + # assets/ — agentskills.io standard directory for supplementary files assets_dir = skill_dir / "assets" if assets_dir.exists(): for f in assets_dir.rglob("*"): if f.is_file(): asset_files.append(str(f.relative_to(skill_dir))) - + scripts_dir = skill_dir / "scripts" if scripts_dir.exists(): - for ext in ['*.py', '*.sh', '*.bash', '*.js', '*.ts', '*.rb']: + for ext in ["*.py", "*.sh", "*.bash", "*.js", "*.ts", "*.rb"]: script_files.extend([str(f.relative_to(skill_dir)) for f in scripts_dir.glob(ext)]) - + # Read tags/related_skills with backward compat: # Check metadata.hermes.* first (agentskills.io convention), fall back to top-level hermes_meta = {} - metadata = frontmatter.get('metadata') + metadata = frontmatter.get("metadata") if isinstance(metadata, dict): - hermes_meta = metadata.get('hermes', {}) or {} - - tags = _parse_tags(hermes_meta.get('tags') or frontmatter.get('tags', '')) - related_skills = _parse_tags(hermes_meta.get('related_skills') or frontmatter.get('related_skills', '')) - + hermes_meta = metadata.get("hermes", {}) or {} + + tags = _parse_tags(hermes_meta.get("tags") or frontmatter.get("tags", "")) + related_skills = _parse_tags(hermes_meta.get("related_skills") or frontmatter.get("related_skills", "")) + # Build linked files structure for clear discovery linked_files = {} if reference_files: @@ -620,34 +632,33 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: linked_files["assets"] = asset_files if script_files: linked_files["scripts"] = script_files - + rel_path = str(skill_md.relative_to(SKILLS_DIR)) - + result = { "success": True, - "name": frontmatter.get('name', skill_md.stem if not skill_dir else skill_dir.name), - "description": frontmatter.get('description', ''), + "name": frontmatter.get("name", skill_md.stem if not skill_dir else skill_dir.name), + "description": frontmatter.get("description", ""), "tags": tags, "related_skills": related_skills, "content": content, "path": rel_path, "linked_files": linked_files if linked_files else None, - "usage_hint": "To view linked files, call skill_view(name, file_path) where file_path is e.g. 'references/api.md' or 'assets/config.yaml'" if linked_files else None + "usage_hint": "To view linked files, call skill_view(name, file_path) where file_path is e.g. 'references/api.md' or 'assets/config.yaml'" + if linked_files + else None, } - + # Surface agentskills.io optional fields when present - if frontmatter.get('compatibility'): - result["compatibility"] = frontmatter['compatibility'] + if frontmatter.get("compatibility"): + result["compatibility"] = frontmatter["compatibility"] if isinstance(metadata, dict): result["metadata"] = metadata - + return json.dumps(result, ensure_ascii=False) - + except Exception as e: - return json.dumps({ - "success": False, - "error": str(e) - }, ensure_ascii=False) + return json.dumps({"success": False, "error": str(e)}, ensure_ascii=False) # Tool description for model_tools.py @@ -669,7 +680,7 @@ if __name__ == "__main__": """Test the skills tool""" print("🎯 Skills Tool Test") print("=" * 60) - + # Test listing skills print("\n📋 Listing all skills:") result = json.loads(skills_list()) @@ -678,12 +689,12 @@ if __name__ == "__main__": print(f"Categories: {result.get('categories', [])}") print("\nFirst 10 skills:") for skill in result["skills"][:10]: - cat = f"[{skill['category']}] " if skill.get('category') else "" - refs = f" (+{len(skill['reference_files'])} refs)" if skill.get('reference_files') else "" + cat = f"[{skill['category']}] " if skill.get("category") else "" + refs = f" (+{len(skill['reference_files'])} refs)" if skill.get("reference_files") else "" print(f" • {cat}{skill['name']}: {skill['description'][:60]}...{refs}") else: print(f"Error: {result['error']}") - + # Test viewing a skill print("\n📖 Viewing skill 'axolotl':") result = json.loads(skill_view("axolotl")) @@ -691,11 +702,11 @@ if __name__ == "__main__": print(f"Name: {result['name']}") print(f"Description: {result.get('description', 'N/A')[:100]}...") print(f"Content length: {len(result['content'])} chars") - if result.get('reference_files'): + if result.get("reference_files"): print(f"Reference files: {result['reference_files']}") else: print(f"Error: {result['error']}") - + # Test viewing a reference file print("\n📄 Viewing reference file 'axolotl/references/dataset-formats.md':") result = json.loads(skill_view("axolotl", "references/dataset-formats.md")) @@ -717,14 +728,9 @@ SKILLS_LIST_SCHEMA = { "description": "List available skills (name + description). Use skill_view(name) to load full content.", "parameters": { "type": "object", - "properties": { - "category": { - "type": "string", - "description": "Optional category filter to narrow results" - } - }, - "required": [] - } + "properties": {"category": {"type": "string", "description": "Optional category filter to narrow results"}}, + "required": [], + }, } SKILL_VIEW_SCHEMA = { @@ -733,17 +739,14 @@ SKILL_VIEW_SCHEMA = { "parameters": { "type": "object", "properties": { - "name": { - "type": "string", - "description": "The skill name (use skills_list to see available skills)" - }, + "name": {"type": "string", "description": "The skill name (use skills_list to see available skills)"}, "file_path": { "type": "string", - "description": "OPTIONAL: Path to a linked file within the skill (e.g., 'references/api.md', 'templates/config.yaml', 'scripts/validate.py'). Omit to get the main SKILL.md content." - } + "description": "OPTIONAL: Path to a linked file within the skill (e.g., 'references/api.md', 'templates/config.yaml', 'scripts/validate.py'). Omit to get the main SKILL.md content.", + }, }, - "required": ["name"] - } + "required": ["name"], + }, } registry.register( diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index e123262c5e..c70b079236 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -26,20 +26,22 @@ Usage: result = terminal_tool("python server.py", background=True) """ +import atexit import json import logging import os -import signal -import sys -import time -import threading -import atexit import shutil -import subprocess -import tempfile -import uuid +import sys +import threading +import time from pathlib import Path -from typing import Optional, Dict, Any +from typing import Any + +from tools.interrupt import ( + _interrupt_event, # noqa: F401 — re-exported to environments/local.py + is_interrupted, # noqa: F401 — re-exported +) +from tools.interrupt import set_interrupt as set_interrupt_event # noqa: F401 — re-exported logger = logging.getLogger(__name__) @@ -49,7 +51,6 @@ logger = logging.getLogger(__name__) # The terminal tool polls this during command execution so it can kill # long-running subprocesses immediately instead of blocking until timeout. # --------------------------------------------------------------------------- -from tools.interrupt import set_interrupt as set_interrupt_event, is_interrupted, _interrupt_event # Add mini-swe-agent to path if not installed @@ -65,7 +66,6 @@ if mini_swe_path.exists(): # Singularity helpers (scratch dir, SIF cache) now live in tools/environments/singularity.py from tools.environments.singularity import _get_scratch_dir - # Disk usage warning threshold (in GB) DISK_USAGE_WARNING_THRESHOLD_GB = float(os.getenv("TERMINAL_DISK_WARNING_GB", "500")) @@ -73,28 +73,32 @@ DISK_USAGE_WARNING_THRESHOLD_GB = float(os.getenv("TERMINAL_DISK_WARNING_GB", "5 def _check_disk_usage_warning(): """Check if total disk usage exceeds warning threshold.""" scratch_dir = _get_scratch_dir() - + try: # Get total size of hermes directories total_bytes = 0 import glob + for path in glob.glob(str(scratch_dir / "hermes-*")): - for f in Path(path).rglob('*'): + for f in Path(path).rglob("*"): if f.is_file(): try: total_bytes += f.stat().st_size except OSError: pass - - total_gb = total_bytes / (1024 ** 3) - + + total_gb = total_bytes / (1024**3) + if total_gb > DISK_USAGE_WARNING_THRESHOLD_GB: - logger.warning("Disk usage (%.1fGB) exceeds threshold (%.0fGB). Consider running cleanup_all_environments().", - total_gb, DISK_USAGE_WARNING_THRESHOLD_GB) + logger.warning( + "Disk usage (%.1fGB) exceeds threshold (%.0fGB). Consider running cleanup_all_environments().", + total_gb, + DISK_USAGE_WARNING_THRESHOLD_GB, + ) return True - + return False - except Exception as e: + except Exception: return False @@ -121,59 +125,59 @@ def set_approval_callback(cb): global _approval_callback _approval_callback = cb + # ============================================================================= # Dangerous Command Approval System # ============================================================================= # Dangerous command detection + approval now consolidated in tools/approval.py from tools.approval import ( - detect_dangerous_command as _detect_dangerous_command, check_dangerous_command as _check_dangerous_command_impl, - load_permanent_allowlist as _load_permanent_allowlist, - DANGEROUS_PATTERNS, ) def _check_dangerous_command(command: str, env_type: str) -> dict: """Delegate to the consolidated approval module, passing the CLI callback.""" - return _check_dangerous_command_impl(command, env_type, - approval_callback=_approval_callback) + return _check_dangerous_command_impl(command, env_type, approval_callback=_approval_callback) def _handle_sudo_failure(output: str, env_type: str) -> str: """ Check for sudo failure and add helpful message for messaging contexts. - + Returns enhanced output if sudo failed in messaging context, else original. """ is_gateway = os.getenv("HERMES_GATEWAY_SESSION") - + if not is_gateway: return output - + # Check for sudo failure indicators sudo_failures = [ "sudo: a password is required", "sudo: no tty present", "sudo: a terminal is required", ] - + for failure in sudo_failures: if failure in output: - return output + "\n\n💡 Tip: To enable sudo over messaging, add SUDO_PASSWORD to ~/.hermes/.env on the agent machine." - + return ( + output + + "\n\n💡 Tip: To enable sudo over messaging, add SUDO_PASSWORD to ~/.hermes/.env on the agent machine." + ) + return output def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str: """ Prompt user for sudo password with timeout. - + Returns the password if entered, or empty string if: - User presses Enter without input (skip) - Timeout expires (45s default) - Any error occurs - + Only works in interactive mode (HERMES_INTERACTIVE=1). If a _sudo_password_callback is registered (by the CLI), delegates to it so the prompt integrates with prompt_toolkit's UI. Otherwise reads @@ -181,7 +185,7 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str: """ import sys import time as time_module - + # Use the registered callback when available (prompt_toolkit-compatible) if _sudo_password_callback is not None: try: @@ -190,13 +194,14 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str: return "" result = {"password": None, "done": False} - + def read_password_thread(): """Read password from /dev/tty with echo disabled.""" tty_fd = None old_attrs = None try: import termios + tty_fd = os.open("/dev/tty", os.O_RDONLY) old_attrs = termios.tcgetattr(tty_fd) new_attrs = termios.tcgetattr(tty_fd) @@ -217,6 +222,7 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str: if tty_fd is not None and old_attrs is not None: try: import termios as _termios + _termios.tcsetattr(tty_fd, _termios.TCSAFLUSH, old_attrs) except Exception: pass @@ -226,11 +232,11 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str: except Exception: pass result["done"] = True - + try: os.environ["HERMES_SPINNER_PAUSE"] = "1" time_module.sleep(0.2) - + print() print("┌" + "─" * 58 + "┐") print("│ 🔐 SUDO PASSWORD REQUIRED" + " " * 30 + "│") @@ -241,11 +247,11 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str: print("└" + "─" * 58 + "┘") print() print(" Password (hidden): ", end="", flush=True) - + password_thread = threading.Thread(target=read_password_thread, daemon=True) password_thread.start() password_thread.join(timeout=timeout_seconds) - + if result["done"]: password = result["password"] or "" print() # newline after hidden input @@ -262,7 +268,7 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str: print() sys.stdout.flush() return "" - + except (EOFError, KeyboardInterrupt): print() print(" ⏭ Cancelled - continuing without sudo") @@ -281,29 +287,29 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str: def _transform_sudo_command(command: str) -> str: """ Transform sudo commands to use -S flag if SUDO_PASSWORD is available. - + This is a shared helper used by all execution environments to provide consistent sudo handling across local, SSH, and container environments. - + If SUDO_PASSWORD is set (via env, config, or interactive prompt): 'sudo apt install curl' -> password piped via sudo -S - + If SUDO_PASSWORD is not set and in interactive mode (HERMES_INTERACTIVE=1): Prompts user for password with 45s timeout, caches for session. - + If SUDO_PASSWORD is not set and NOT interactive: Command runs as-is (fails gracefully with "sudo: a password is required"). """ global _cached_sudo_password import re - + # Check if command even contains sudo - if not re.search(r'\bsudo\b', command): + if not re.search(r"\bsudo\b", command): return command # No sudo in command, return as-is - + # Try to get password from: env var -> session cache -> interactive prompt sudo_password = os.getenv("SUDO_PASSWORD", "") or _cached_sudo_password - + if not sudo_password: # No password configured - check if we're in interactive mode if os.getenv("HERMES_INTERACTIVE"): @@ -311,30 +317,30 @@ def _transform_sudo_command(command: str) -> str: sudo_password = _prompt_for_sudo_password(timeout_seconds=45) if sudo_password: _cached_sudo_password = sudo_password # Cache for session - + if not sudo_password: return command # No password, let it fail gracefully - + def replace_sudo(match): # Replace 'sudo' with password-piped version # The -S flag makes sudo read password from stdin # The -p '' suppresses the password prompt # Use shlex.quote() to prevent shell injection via password content import shlex + return f"echo {shlex.quote(sudo_password)} | sudo -S -p ''" - + # Match 'sudo' at word boundaries (not 'visudo' or 'sudoers') # This handles: sudo, sudo -flag, etc. - return re.sub(r'\bsudo\b', replace_sudo, command) + return re.sub(r"\bsudo\b", replace_sudo, command) # Environment classes now live in tools/environments/ +from tools.environments.docker import DockerEnvironment as _DockerEnvironment from tools.environments.local import LocalEnvironment as _LocalEnvironment +from tools.environments.modal import ModalEnvironment as _ModalEnvironment from tools.environments.singularity import SingularityEnvironment as _SingularityEnvironment from tools.environments.ssh import SSHEnvironment as _SSHEnvironment -from tools.environments.docker import DockerEnvironment as _DockerEnvironment -from tools.environments.modal import ModalEnvironment as _ModalEnvironment - # Tool description for LLM TERMINAL_TOOL_DESCRIPTION = """Execute shell commands on a Linux environment. Filesystem persists between calls. @@ -356,10 +362,10 @@ Do NOT use vim/nano/interactive tools without pty=true — they hang without a p """ # Global state for environment lifecycle management -_active_environments: Dict[str, Any] = {} -_last_activity: Dict[str, float] = {} +_active_environments: dict[str, Any] = {} +_last_activity: dict[str, float] = {} _env_lock = threading.Lock() -_creation_locks: Dict[str, threading.Lock] = {} # Per-task locks for sandbox creation +_creation_locks: dict[str, threading.Lock] = {} # Per-task locks for sandbox creation _creation_locks_lock = threading.Lock() # Protects _creation_locks dict itself _cleanup_thread = None _cleanup_running = False @@ -372,10 +378,10 @@ _cleanup_running = False # # This is never exposed to the model -- only infrastructure code calls it. # Thread-safe because each task_id is unique per rollout. -_task_env_overrides: Dict[str, Dict[str, Any]] = {} +_task_env_overrides: dict[str, dict[str, Any]] = {} -def register_task_env_overrides(task_id: str, overrides: Dict[str, Any]): +def register_task_env_overrides(task_id: str, overrides: dict[str, Any]): """ Register environment overrides for a specific task/rollout. @@ -402,13 +408,14 @@ def clear_task_env_overrides(task_id: str): """ _task_env_overrides.pop(task_id, None) + # Configuration from environment variables -def _get_env_config() -> Dict[str, Any]: +def _get_env_config() -> dict[str, Any]: """Get terminal environment configuration from environment variables.""" # Default image with Python and Node.js for maximum compatibility default_image = "nikolaik/python-nodejs:python3.11-nodejs20" env_type = os.getenv("TERMINAL_ENV", "local") - + # Default cwd: local uses the host's current directory, everything # else starts in the user's home (~ resolves to whatever account # is running inside the container/remote). @@ -416,7 +423,7 @@ def _get_env_config() -> Dict[str, Any]: default_cwd = os.getcwd() else: default_cwd = "~" - + # Read TERMINAL_CWD but sanity-check it for container backends. # If the CWD looks like a host-local path that can't exist inside a # container/sandbox, fall back to the backend's own default. This @@ -426,9 +433,12 @@ def _get_env_config() -> Dict[str, Any]: if env_type in ("modal", "docker", "singularity", "daytona") and cwd: host_prefixes = ("/Users/", "C:\\", "C:/") if any(cwd.startswith(p) for p in host_prefixes) and cwd != default_cwd: - logger.info("Ignoring TERMINAL_CWD=%r for %s backend " - "(host path won't exist in sandbox). Using %r instead.", - cwd, env_type, default_cwd) + logger.info( + "Ignoring TERMINAL_CWD=%r for %s backend (host path won't exist in sandbox). Using %r instead.", + cwd, + env_type, + default_cwd, + ) cwd = default_cwd return { @@ -447,19 +457,25 @@ def _get_env_config() -> Dict[str, Any]: "ssh_key": os.getenv("TERMINAL_SSH_KEY", ""), # Container resource config (applies to docker, singularity, modal, daytona -- ignored for local/ssh) "container_cpu": float(os.getenv("TERMINAL_CONTAINER_CPU", "1")), - "container_memory": int(os.getenv("TERMINAL_CONTAINER_MEMORY", "5120")), # MB (default 5GB) - "container_disk": int(os.getenv("TERMINAL_CONTAINER_DISK", "51200")), # MB (default 50GB) + "container_memory": int(os.getenv("TERMINAL_CONTAINER_MEMORY", "5120")), # MB (default 5GB) + "container_disk": int(os.getenv("TERMINAL_CONTAINER_DISK", "51200")), # MB (default 50GB) "container_persistent": os.getenv("TERMINAL_CONTAINER_PERSISTENT", "true").lower() in ("true", "1", "yes"), "docker_volumes": json.loads(os.getenv("TERMINAL_DOCKER_VOLUMES", "[]")), } -def _create_environment(env_type: str, image: str, cwd: str, timeout: int, - ssh_config: dict = None, container_config: dict = None, - task_id: str = "default"): +def _create_environment( + env_type: str, + image: str, + cwd: str, + timeout: int, + ssh_config: dict = None, + container_config: dict = None, + task_id: str = "default", +): """ Create an execution environment from mini-swe-agent. - + Args: env_type: One of "local", "docker", "singularity", "modal", "daytona", "ssh" image: Docker/Singularity/Modal image name (ignored for local/ssh) @@ -468,7 +484,7 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, ssh_config: SSH connection config (for env_type="ssh") container_config: Resource config for container backends (cpu, memory, disk, persistent) task_id: Task identifier for environment reuse and snapshot keying - + Returns: Environment instance with execute() method """ @@ -481,22 +497,32 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, if env_type == "local": return _LocalEnvironment(cwd=cwd, timeout=timeout) - + elif env_type == "docker": return _DockerEnvironment( - image=image, cwd=cwd, timeout=timeout, - cpu=cpu, memory=memory, disk=disk, - persistent_filesystem=persistent, task_id=task_id, + image=image, + cwd=cwd, + timeout=timeout, + cpu=cpu, + memory=memory, + disk=disk, + persistent_filesystem=persistent, + task_id=task_id, volumes=volumes, ) - + elif env_type == "singularity": return _SingularityEnvironment( - image=image, cwd=cwd, timeout=timeout, - cpu=cpu, memory=memory, disk=disk, - persistent_filesystem=persistent, task_id=task_id, + image=image, + cwd=cwd, + timeout=timeout, + cpu=cpu, + memory=memory, + disk=disk, + persistent_filesystem=persistent, + task_id=task_id, ) - + elif env_type == "modal": sandbox_kwargs = {} if cpu > 0: @@ -505,20 +531,29 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, sandbox_kwargs["memory"] = memory if disk > 0: sandbox_kwargs["ephemeral_disk"] = disk - + return _ModalEnvironment( - image=image, cwd=cwd, timeout=timeout, + image=image, + cwd=cwd, + timeout=timeout, modal_sandbox_kwargs=sandbox_kwargs, - persistent_filesystem=persistent, task_id=task_id, + persistent_filesystem=persistent, + task_id=task_id, ) - + elif env_type == "daytona": # Lazy import so daytona SDK is only required when backend is selected. from tools.environments.daytona import DaytonaEnvironment as _DaytonaEnvironment + return _DaytonaEnvironment( - image=image, cwd=cwd, timeout=timeout, - cpu=int(cpu), memory=memory, disk=disk, - persistent_filesystem=persistent, task_id=task_id, + image=image, + cwd=cwd, + timeout=timeout, + cpu=int(cpu), + memory=memory, + disk=disk, + persistent_filesystem=persistent, + task_id=task_id, ) elif env_type == "ssh": @@ -534,7 +569,9 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, ) else: - raise ValueError(f"Unknown environment type: {env_type}. Use 'local', 'docker', 'singularity', 'modal', 'daytona', or 'ssh'") + raise ValueError( + f"Unknown environment type: {env_type}. Use 'local', 'docker', 'singularity', 'modal', 'daytona', or 'ssh'" + ) def _cleanup_inactive_envs(lifetime_seconds: int = 300): @@ -547,6 +584,7 @@ def _cleanup_inactive_envs(lifetime_seconds: int = 300): # background processes (their _last_activity gets refreshed to keep them alive). try: from tools.process_registry import process_registry + for task_id in list(_last_activity.keys()): if process_registry.has_active_processes(task_id): _last_activity[task_id] = current_time # Keep sandbox alive @@ -579,16 +617,17 @@ def _cleanup_inactive_envs(lifetime_seconds: int = 300): # ShellFileOperations from referencing a dead sandbox) try: from tools.file_tools import clear_file_ops_cache + clear_file_ops_cache(task_id) except ImportError: pass try: - if hasattr(env, 'cleanup'): + if hasattr(env, "cleanup"): env.cleanup() - elif hasattr(env, 'stop'): + elif hasattr(env, "stop"): env.stop() - elif hasattr(env, 'terminate'): + elif hasattr(env, "terminate"): env.terminate() logger.info("Cleaned up inactive environment for task: %s", task_id) @@ -640,27 +679,28 @@ def _stop_cleanup_thread(): pass -def get_active_environments_info() -> Dict[str, Any]: +def get_active_environments_info() -> dict[str, Any]: """Get information about currently active environments.""" info = { "count": len(_active_environments), "task_ids": list(_active_environments.keys()), "workdirs": {}, } - + # Calculate total disk usage (per-task to avoid double-counting) total_size = 0 - for task_id in _active_environments.keys(): + for task_id in _active_environments: scratch_dir = _get_scratch_dir() pattern = f"hermes-*{task_id[:8]}*" import glob + for path in glob.glob(str(scratch_dir / pattern)): try: - size = sum(f.stat().st_size for f in Path(path).rglob('*') if f.is_file()) + size = sum(f.stat().st_size for f in Path(path).rglob("*") if f.is_file()) total_size += size except OSError: pass - + info["total_disk_usage_mb"] = round(total_size / (1024 * 1024), 2) return info @@ -668,27 +708,28 @@ def get_active_environments_info() -> Dict[str, Any]: def cleanup_all_environments(): """Clean up ALL active environments. Use with caution.""" global _active_environments, _last_activity - + task_ids = list(_active_environments.keys()) cleaned = 0 - + for task_id in task_ids: try: cleanup_vm(task_id) cleaned += 1 except Exception as e: logger.error("Error cleaning %s: %s", task_id, e, exc_info=True) - + # Also clean any orphaned directories scratch_dir = _get_scratch_dir() import glob + for path in glob.glob(str(scratch_dir / "hermes-*")): try: shutil.rmtree(path, ignore_errors=True) logger.info("Removed orphaned: %s", path) except OSError: pass - + if cleaned > 0: logger.info("Cleaned %d environments", cleaned) return cleaned @@ -713,6 +754,7 @@ def cleanup_vm(task_id: str): # Invalidate stale file_ops cache entry try: from tools.file_tools import clear_file_ops_cache + clear_file_ops_cache(task_id) except ImportError: pass @@ -721,11 +763,11 @@ def cleanup_vm(task_id: str): return try: - if hasattr(env, 'cleanup'): + if hasattr(env, "cleanup"): env.cleanup() - elif hasattr(env, 'stop'): + elif hasattr(env, "stop"): env.stop() - elif hasattr(env, 'terminate'): + elif hasattr(env, "terminate"): env.terminate() logger.info("Manually cleaned up environment for task: %s", task_id) @@ -746,17 +788,18 @@ def _atexit_cleanup(): logger.info("Shutting down %d remaining sandbox(es)...", count) cleanup_all_environments() + atexit.register(_atexit_cleanup) def terminal_tool( command: str, background: bool = False, - timeout: Optional[int] = None, - task_id: Optional[str] = None, + timeout: int | None = None, + task_id: str | None = None, force: bool = False, - workdir: Optional[str] = None, - check_interval: Optional[int] = None, + workdir: str | None = None, + check_interval: int | None = None, pty: bool = False, ) -> str: """ @@ -784,7 +827,7 @@ def terminal_tool( # With custom timeout >>> result = terminal_tool(command="long_task.sh", timeout=300) - + # Force run after user confirmation # Note: force parameter is internal only, not exposed to model API """ @@ -801,7 +844,7 @@ def terminal_tool( # Check per-task overrides (set by environments like TerminalBench2Env) # before falling back to global env var config overrides = _task_env_overrides.get(effective_task_id, {}) - + # Select image based on env type, with per-task override support if env_type == "docker": image = overrides.get("docker_image") or config["docker_image"] @@ -882,12 +925,15 @@ def terminal_tool( task_id=effective_task_id, ) except ImportError as e: - return json.dumps({ - "output": "", - "exit_code": -1, - "error": f"Terminal tool disabled: mini-swe-agent not available ({e})", - "status": "disabled" - }, ensure_ascii=False) + return json.dumps( + { + "output": "", + "exit_code": -1, + "error": f"Terminal tool disabled: mini-swe-agent not available ({e})", + "status": "disabled", + }, + ensure_ascii=False, + ) with _env_lock: _active_environments[effective_task_id] = new_env @@ -902,27 +948,33 @@ def terminal_tool( if not approval["approved"]: # Check if this is an approval_required (gateway ask mode) if approval.get("status") == "approval_required": - return json.dumps({ - "output": "", - "exit_code": -1, - "error": approval.get("message", "Waiting for user approval"), - "status": "approval_required", - "command": approval.get("command", command), - "description": approval.get("description", "dangerous command"), - "pattern_key": approval.get("pattern_key", ""), - }, ensure_ascii=False) + return json.dumps( + { + "output": "", + "exit_code": -1, + "error": approval.get("message", "Waiting for user approval"), + "status": "approval_required", + "command": approval.get("command", command), + "description": approval.get("description", "dangerous command"), + "pattern_key": approval.get("pattern_key", ""), + }, + ensure_ascii=False, + ) # Command was blocked - include the pattern category so the caller knows why desc = approval.get("description", "potentially dangerous operation") fallback_msg = ( f"Command denied: matches '{desc}' pattern. " "Use the approval prompt to allow it, or rephrase the command." ) - return json.dumps({ - "output": "", - "exit_code": -1, - "error": approval.get("message", fallback_msg), - "status": "blocked" - }, ensure_ascii=False) + return json.dumps( + { + "output": "", + "exit_code": -1, + "error": approval.get("message", fallback_msg), + "status": "blocked", + }, + ensure_ascii=False, + ) # Prepare command for execution if background: @@ -940,7 +992,7 @@ def terminal_tool( cwd=effective_cwd, task_id=effective_task_id, session_key=session_key, - env_vars=env.env if hasattr(env, 'env') else None, + env_vars=env.env if hasattr(env, "env") else None, use_pty=pty, ) else: @@ -964,38 +1016,36 @@ def terminal_tool( max_timeout = effective_timeout if timeout and timeout > max_timeout: result_data["timeout_note"] = ( - f"Requested timeout {timeout}s was clamped to " - f"configured limit of {max_timeout}s" + f"Requested timeout {timeout}s was clamped to configured limit of {max_timeout}s" ) # Register check_interval watcher (gateway picks this up after agent run) if check_interval and background: effective_interval = max(30, check_interval) if check_interval < 30: - result_data["check_interval_note"] = ( - f"Requested {check_interval}s raised to minimum 30s" - ) - process_registry.pending_watchers.append({ - "session_id": proc_session.id, - "check_interval": effective_interval, - "session_key": session_key, - "platform": os.getenv("HERMES_SESSION_PLATFORM", ""), - "chat_id": os.getenv("HERMES_SESSION_CHAT_ID", ""), - }) + result_data["check_interval_note"] = f"Requested {check_interval}s raised to minimum 30s" + process_registry.pending_watchers.append( + { + "session_id": proc_session.id, + "check_interval": effective_interval, + "session_key": session_key, + "platform": os.getenv("HERMES_SESSION_PLATFORM", ""), + "chat_id": os.getenv("HERMES_SESSION_CHAT_ID", ""), + } + ) return json.dumps(result_data, ensure_ascii=False) except Exception as e: - return json.dumps({ - "output": "", - "exit_code": -1, - "error": f"Failed to start background process: {str(e)}" - }, ensure_ascii=False) + return json.dumps( + {"output": "", "exit_code": -1, "error": f"Failed to start background process: {str(e)}"}, + ensure_ascii=False, + ) else: # Run foreground command with retry logic max_retries = 3 retry_count = 0 result = None - + while retry_count <= max_retries: try: execute_kwargs = {"timeout": effective_timeout} @@ -1005,39 +1055,61 @@ def terminal_tool( except Exception as e: error_str = str(e).lower() if "timeout" in error_str: - return json.dumps({ - "output": "", - "exit_code": 124, - "error": f"Command timed out after {effective_timeout} seconds" - }, ensure_ascii=False) - + return json.dumps( + { + "output": "", + "exit_code": 124, + "error": f"Command timed out after {effective_timeout} seconds", + }, + ensure_ascii=False, + ) + # Retry on transient errors if retry_count < max_retries: retry_count += 1 - wait_time = 2 ** retry_count - logger.warning("Execution error, retrying in %ds (attempt %d/%d) - Command: %s - Error: %s: %s - Task: %s, Backend: %s", - wait_time, retry_count, max_retries, command[:200], type(e).__name__, e, effective_task_id, env_type) + wait_time = 2**retry_count + logger.warning( + "Execution error, retrying in %ds (attempt %d/%d) - Command: %s - Error: %s: %s - Task: %s, Backend: %s", + wait_time, + retry_count, + max_retries, + command[:200], + type(e).__name__, + e, + effective_task_id, + env_type, + ) time.sleep(wait_time) continue - - logger.error("Execution failed after %d retries - Command: %s - Error: %s: %s - Task: %s, Backend: %s", - max_retries, command[:200], type(e).__name__, e, effective_task_id, env_type) - return json.dumps({ - "output": "", - "exit_code": -1, - "error": f"Command execution failed: {type(e).__name__}: {str(e)}" - }, ensure_ascii=False) - + + logger.error( + "Execution failed after %d retries - Command: %s - Error: %s: %s - Task: %s, Backend: %s", + max_retries, + command[:200], + type(e).__name__, + e, + effective_task_id, + env_type, + ) + return json.dumps( + { + "output": "", + "exit_code": -1, + "error": f"Command execution failed: {type(e).__name__}: {str(e)}", + }, + ensure_ascii=False, + ) + # Got a result break - + # Extract output output = result.get("output", "") returncode = result.get("returncode", 0) - + # Add helpful message for sudo failures in messaging context output = _handle_sudo_failure(output, env_type) - + # Truncate output if too long, keeping both head and tail MAX_OUTPUT_CHARS = 50000 if len(output) > MAX_OUTPUT_CHARS: @@ -1045,65 +1117,56 @@ def terminal_tool( tail_chars = MAX_OUTPUT_CHARS - head_chars # 60% tail (most recent/relevant output) omitted = len(output) - head_chars - tail_chars truncated_notice = ( - f"\n\n... [OUTPUT TRUNCATED - {omitted} chars omitted " - f"out of {len(output)} total] ...\n\n" + f"\n\n... [OUTPUT TRUNCATED - {omitted} chars omitted out of {len(output)} total] ...\n\n" ) output = output[:head_chars] + truncated_notice + output[-tail_chars:] # Redact secrets from command output (catches env/printenv leaking keys) from agent.redact import redact_sensitive_text + output = redact_sensitive_text(output.strip()) if output else "" - return json.dumps({ - "output": output, - "exit_code": returncode, - "error": None - }, ensure_ascii=False) + return json.dumps({"output": output, "exit_code": returncode, "error": None}, ensure_ascii=False) except Exception as e: - return json.dumps({ - "output": "", - "exit_code": -1, - "error": f"Failed to execute command: {str(e)}", - "status": "error" - }, ensure_ascii=False) + return json.dumps( + {"output": "", "exit_code": -1, "error": f"Failed to execute command: {str(e)}", "status": "error"}, + ensure_ascii=False, + ) def check_terminal_requirements() -> bool: """Check if all requirements for the terminal tool are met.""" config = _get_env_config() env_type = config["env_type"] - + try: if env_type == "local": - from minisweagent.environments.local import LocalEnvironment return True elif env_type == "docker": - from minisweagent.environments.docker import DockerEnvironment # Check if docker is available import subprocess + result = subprocess.run(["docker", "version"], capture_output=True, timeout=5) return result.returncode == 0 elif env_type == "singularity": - from minisweagent.environments.singularity import SingularityEnvironment + import shutil + # Check if singularity/apptainer is available import subprocess - import shutil + executable = shutil.which("apptainer") or shutil.which("singularity") if executable: result = subprocess.run([executable, "--version"], capture_output=True, timeout=5) return result.returncode == 0 return False elif env_type == "ssh": - from tools.environments.ssh import SSHEnvironment # Check that host and user are configured return bool(config.get("ssh_host")) and bool(config.get("ssh_user")) elif env_type == "modal": - from minisweagent.environments.extra.swerex_modal import SwerexModalEnvironment # Check for modal token return os.getenv("MODAL_TOKEN_ID") is not None or Path.home().joinpath(".modal.toml").exists() elif env_type == "daytona": - from daytona import Daytona return os.getenv("DAYTONA_API_KEY") is not None else: return False @@ -1116,9 +1179,9 @@ if __name__ == "__main__": # Simple test when run directly print("Terminal Tool Module (mini-swe-agent backend)") print("=" * 50) - + config = _get_env_config() - print(f"\nCurrent Configuration:") + print("\nCurrent Configuration:") print(f" Environment type: {config['env_type']}") print(f" Docker image: {config['docker_image']}") print(f" Modal image: {config['modal_image']}") @@ -1165,37 +1228,34 @@ TERMINAL_SCHEMA = { "parameters": { "type": "object", "properties": { - "command": { - "type": "string", - "description": "The command to execute on the VM" - }, + "command": {"type": "string", "description": "The command to execute on the VM"}, "background": { "type": "boolean", "description": "ONLY for servers/watchers that never exit. For scripts, builds, installs — use foreground with timeout instead (it returns instantly when done).", - "default": False + "default": False, }, "timeout": { "type": "integer", "description": "Max seconds to wait (default: 180). Returns INSTANTLY when command finishes — set high for long tasks, you won't wait unnecessarily.", - "minimum": 1 + "minimum": 1, }, "workdir": { "type": "string", - "description": "Working directory for this command (absolute path). Defaults to the session working directory." + "description": "Working directory for this command (absolute path). Defaults to the session working directory.", }, "check_interval": { "type": "integer", "description": "Seconds between automatic status checks for background processes (gateway/messaging only, minimum 30). When set, I'll proactively report progress.", - "minimum": 30 + "minimum": 30, }, "pty": { "type": "boolean", "description": "Run in pseudo-terminal (PTY) mode for interactive CLI tools like Codex, Claude Code, or Python REPL. Only works with local and SSH backends. Default: false.", - "default": False - } + "default": False, + }, }, - "required": ["command"] - } + "required": ["command"], + }, } diff --git a/tools/todo_tool.py b/tools/todo_tool.py index a4853ac3b3..970b372bf5 100644 --- a/tools/todo_tool.py +++ b/tools/todo_tool.py @@ -15,8 +15,7 @@ Design: """ import json -from typing import Dict, Any, List, Optional - +from typing import Any # Valid status values for todo items VALID_STATUSES = {"pending", "in_progress", "completed", "cancelled"} @@ -33,9 +32,9 @@ class TodoStore: """ def __init__(self): - self._items: List[Dict[str, str]] = [] + self._items: list[dict[str, str]] = [] - def write(self, todos: List[Dict[str, Any]], merge: bool = False) -> List[Dict[str, str]]: + def write(self, todos: list[dict[str, Any]], merge: bool = False) -> list[dict[str, str]]: """ Write todos. Returns the full current list after writing. @@ -79,7 +78,7 @@ class TodoStore: self._items = rebuilt return self.read() - def read(self) -> List[Dict[str, str]]: + def read(self) -> list[dict[str, str]]: """Return a copy of the current list.""" return [item.copy() for item in self._items] @@ -87,7 +86,7 @@ class TodoStore: """Check if there are any items in the list.""" return len(self._items) > 0 - def format_for_injection(self) -> Optional[str]: + def format_for_injection(self) -> str | None: """ Render the todo list for post-compression injection. @@ -113,7 +112,7 @@ class TodoStore: return "\n".join(lines) @staticmethod - def _validate(item: Dict[str, Any]) -> Dict[str, str]: + def _validate(item: dict[str, Any]) -> dict[str, str]: """ Validate and normalize a todo item. @@ -136,9 +135,9 @@ class TodoStore: def todo_tool( - todos: Optional[List[Dict[str, Any]]] = None, + todos: list[dict[str, Any]] | None = None, merge: bool = False, - store: Optional[TodoStore] = None, + store: TodoStore | None = None, ) -> str: """ Single entry point for the todo tool. Reads or writes depending on params. @@ -165,16 +164,19 @@ def todo_tool( completed = sum(1 for i in items if i["status"] == "completed") cancelled = sum(1 for i in items if i["status"] == "cancelled") - return json.dumps({ - "todos": items, - "summary": { - "total": len(items), - "pending": pending, - "in_progress": in_progress, - "completed": completed, - "cancelled": cancelled, + return json.dumps( + { + "todos": items, + "summary": { + "total": len(items), + "pending": pending, + "in_progress": in_progress, + "completed": completed, + "cancelled": cancelled, + }, }, - }, ensure_ascii=False) + ensure_ascii=False, + ) def check_todo_requirements() -> bool: @@ -214,34 +216,27 @@ TODO_SCHEMA = { "items": { "type": "object", "properties": { - "id": { - "type": "string", - "description": "Unique item identifier" - }, - "content": { - "type": "string", - "description": "Task description" - }, + "id": {"type": "string", "description": "Unique item identifier"}, + "content": {"type": "string", "description": "Task description"}, "status": { "type": "string", "enum": ["pending", "in_progress", "completed", "cancelled"], - "description": "Current status" - } + "description": "Current status", + }, }, - "required": ["id", "content", "status"] - } + "required": ["id", "content", "status"], + }, }, "merge": { "type": "boolean", "description": ( - "true: update existing items by id, add new ones. " - "false (default): replace the entire list." + "true: update existing items by id, add new ones. false (default): replace the entire list." ), - "default": False - } + "default": False, + }, }, - "required": [] - } + "required": [], + }, } @@ -253,6 +248,7 @@ registry.register( toolset="todo", schema=TODO_SCHEMA, handler=lambda args, **kw: todo_tool( - todos=args.get("todos"), merge=args.get("merge", False), store=kw.get("store")), + todos=args.get("todos"), merge=args.get("merge", False), store=kw.get("store") + ), check_fn=check_todo_requirements, ) diff --git a/tools/transcription_tools.py b/tools/transcription_tools.py index 8e26e0941b..e53c3a1042 100644 --- a/tools/transcription_tools.py +++ b/tools/transcription_tools.py @@ -24,7 +24,7 @@ Usage: import logging import os from pathlib import Path -from typing import Optional, Dict, Any +from typing import Any logger = logging.getLogger(__name__) @@ -39,7 +39,7 @@ SUPPORTED_FORMATS = {".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm", MAX_FILE_SIZE = 25 * 1024 * 1024 -def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, Any]: +def transcribe_audio(file_path: str, model: str | None = None) -> dict[str, Any]: """ Transcribe an audio file using OpenAI's Whisper API. @@ -65,7 +65,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A } audio_path = Path(file_path) - + # Validate file exists if not audio_path.exists(): return { @@ -73,14 +73,14 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A "transcript": "", "error": f"Audio file not found: {file_path}", } - + if not audio_path.is_file(): return { "success": False, "transcript": "", "error": f"Path is not a file: {file_path}", } - + # Validate file extension if audio_path.suffix.lower() not in SUPPORTED_FORMATS: return { @@ -88,7 +88,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A "transcript": "", "error": f"Unsupported file format: {audio_path.suffix}. Supported formats: {', '.join(sorted(SUPPORTED_FORMATS))}", } - + # Validate file size try: file_size = audio_path.stat().st_size @@ -96,7 +96,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A return { "success": False, "transcript": "", - "error": f"File too large: {file_size / (1024*1024):.1f}MB (max {MAX_FILE_SIZE / (1024*1024)}MB)", + "error": f"File too large: {file_size / (1024 * 1024):.1f}MB (max {MAX_FILE_SIZE / (1024 * 1024)}MB)", } except OSError as e: logger.error("Failed to get file size for %s: %s", file_path, e, exc_info=True) @@ -111,7 +111,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A model = DEFAULT_STT_MODEL try: - from openai import OpenAI, APIError, APIConnectionError, APITimeoutError + from openai import APIConnectionError, APIError, APITimeoutError, OpenAI client = OpenAI(api_key=api_key, base_url="https://api.openai.com/v1") diff --git a/tools/tts_tool.py b/tools/tts_tool.py index 8e8f5e928f..4498c84247 100644 --- a/tools/tts_tool.py +++ b/tools/tts_tool.py @@ -27,9 +27,8 @@ import logging import os import shutil import subprocess -import tempfile from pathlib import Path -from typing import Dict, Any, Optional +from typing import Any logger = logging.getLogger(__name__) @@ -38,12 +37,14 @@ logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- try: import edge_tts + _HAS_EDGE_TTS = True except ImportError: _HAS_EDGE_TTS = False try: from elevenlabs.client import ElevenLabs + _HAS_ELEVENLABS = True except ImportError: _HAS_ELEVENLABS = False @@ -51,6 +52,7 @@ except ImportError: # openai is a core dependency, but guard anyway try: from openai import OpenAI as OpenAIClient + _HAS_OPENAI = True except ImportError: _HAS_OPENAI = False @@ -72,7 +74,7 @@ MAX_TEXT_LENGTH = 4000 # =========================================================================== # Config loader -- reads tts: section from ~/.hermes/config.yaml # =========================================================================== -def _load_tts_config() -> Dict[str, Any]: +def _load_tts_config() -> dict[str, Any]: """ Load TTS configuration from ~/.hermes/config.yaml. @@ -81,13 +83,14 @@ def _load_tts_config() -> Dict[str, Any]: """ try: from hermes_cli.config import load_config + config = load_config() return config.get("tts", {}) except Exception: return {} -def _get_provider(tts_config: Dict[str, Any]) -> str: +def _get_provider(tts_config: dict[str, Any]) -> str: """Get the configured TTS provider name.""" return tts_config.get("provider", DEFAULT_PROVIDER).lower().strip() @@ -100,7 +103,7 @@ def _has_ffmpeg() -> bool: return shutil.which("ffmpeg") is not None -def _convert_to_opus(mp3_path: str) -> Optional[str]: +def _convert_to_opus(mp3_path: str) -> str | None: """ Convert an MP3 file to OGG Opus format for Telegram voice bubbles. @@ -116,9 +119,9 @@ def _convert_to_opus(mp3_path: str) -> Optional[str]: ogg_path = mp3_path.rsplit(".", 1)[0] + ".ogg" try: subprocess.run( - ["ffmpeg", "-i", mp3_path, "-acodec", "libopus", - "-ac", "1", "-b:a", "64k", "-vbr", "off", ogg_path, "-y"], - capture_output=True, timeout=30, + ["ffmpeg", "-i", mp3_path, "-acodec", "libopus", "-ac", "1", "-b:a", "64k", "-vbr", "off", ogg_path, "-y"], + capture_output=True, + timeout=30, ) if os.path.exists(ogg_path) and os.path.getsize(ogg_path) > 0: return ogg_path @@ -130,7 +133,7 @@ def _convert_to_opus(mp3_path: str) -> Optional[str]: # =========================================================================== # Provider: Edge TTS (free) # =========================================================================== -async def _generate_edge_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -> str: +async def _generate_edge_tts(text: str, output_path: str, tts_config: dict[str, Any]) -> str: """ Generate audio using Edge TTS. @@ -153,7 +156,7 @@ async def _generate_edge_tts(text: str, output_path: str, tts_config: Dict[str, # =========================================================================== # Provider: ElevenLabs (premium) # =========================================================================== -def _generate_elevenlabs(text: str, output_path: str, tts_config: Dict[str, Any]) -> str: +def _generate_elevenlabs(text: str, output_path: str, tts_config: dict[str, Any]) -> str: """ Generate audio using ElevenLabs. @@ -198,7 +201,7 @@ def _generate_elevenlabs(text: str, output_path: str, tts_config: Dict[str, Any] # =========================================================================== # Provider: OpenAI TTS # =========================================================================== -def _generate_openai_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -> str: +def _generate_openai_tts(text: str, output_path: str, tts_config: dict[str, Any]) -> str: """ Generate audio using OpenAI TTS. @@ -241,7 +244,7 @@ def _generate_openai_tts(text: str, output_path: str, tts_config: Dict[str, Any] # =========================================================================== def text_to_speech_tool( text: str, - output_path: Optional[str] = None, + output_path: str | None = None, ) -> str: """ Convert text to speech audio. @@ -276,7 +279,7 @@ def text_to_speech_tool( # produce Opus natively (no ffmpeg needed). Edge TTS always outputs MP3 # and needs ffmpeg for conversion. platform = os.getenv("HERMES_SESSION_PLATFORM", "").lower() - want_opus = (platform == "telegram") + want_opus = platform == "telegram" # Determine output path if output_path: @@ -300,47 +303,48 @@ def text_to_speech_tool( # Generate audio with the configured provider if provider == "elevenlabs": if not _HAS_ELEVENLABS: - return json.dumps({ - "success": False, - "error": "ElevenLabs provider selected but 'elevenlabs' package not installed. Run: pip install elevenlabs" - }, ensure_ascii=False) + return json.dumps( + { + "success": False, + "error": "ElevenLabs provider selected but 'elevenlabs' package not installed. Run: pip install elevenlabs", + }, + ensure_ascii=False, + ) logger.info("Generating speech with ElevenLabs...") _generate_elevenlabs(text, file_str, tts_config) elif provider == "openai": if not _HAS_OPENAI: - return json.dumps({ - "success": False, - "error": "OpenAI provider selected but 'openai' package not installed." - }, ensure_ascii=False) + return json.dumps( + {"success": False, "error": "OpenAI provider selected but 'openai' package not installed."}, + ensure_ascii=False, + ) logger.info("Generating speech with OpenAI TTS...") _generate_openai_tts(text, file_str, tts_config) else: # Default: Edge TTS (free) if not _HAS_EDGE_TTS: - return json.dumps({ - "success": False, - "error": "Edge TTS not available. Run: pip install edge-tts" - }, ensure_ascii=False) + return json.dumps( + {"success": False, "error": "Edge TTS not available. Run: pip install edge-tts"}, ensure_ascii=False + ) logger.info("Generating speech with Edge TTS...") # Edge TTS is async, run it try: loop = asyncio.get_running_loop() import concurrent.futures + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: - pool.submit( - lambda: asyncio.run(_generate_edge_tts(text, file_str, tts_config)) - ).result(timeout=60) + pool.submit(lambda: asyncio.run(_generate_edge_tts(text, file_str, tts_config))).result(timeout=60) except RuntimeError: asyncio.run(_generate_edge_tts(text, file_str, tts_config)) # Check the file was actually created if not os.path.exists(file_str) or os.path.getsize(file_str) == 0: - return json.dumps({ - "success": False, - "error": f"TTS generation produced no output (provider: {provider})" - }, ensure_ascii=False) + return json.dumps( + {"success": False, "error": f"TTS generation produced no output (provider: {provider})"}, + ensure_ascii=False, + ) # Try Opus conversion for Telegram compatibility (Edge TTS only outputs MP3) voice_compatible = False @@ -361,13 +365,16 @@ def text_to_speech_tool( if voice_compatible: media_tag = f"[[audio_as_voice]]\n{media_tag}" - return json.dumps({ - "success": True, - "file_path": file_str, - "media_tag": media_tag, - "provider": provider, - "voice_compatible": voice_compatible, - }, ensure_ascii=False) + return json.dumps( + { + "success": True, + "file_path": file_str, + "media_tag": media_tag, + "provider": provider, + "voice_compatible": voice_compatible, + }, + ensure_ascii=False, + ) except Exception as e: error_msg = f"TTS generation failed ({provider}): {e}" @@ -404,7 +411,7 @@ if __name__ == "__main__": print("🔊 Text-to-Speech Tool Module") print("=" * 50) - print(f"\nProvider availability:") + print("\nProvider availability:") print(f" Edge TTS: {'✅ installed' if _HAS_EDGE_TTS else '❌ not installed (pip install edge-tts)'}") print(f" ElevenLabs: {'✅ installed' if _HAS_ELEVENLABS else '❌ not installed (pip install elevenlabs)'}") print(f" API Key: {'✅ set' if os.getenv('ELEVENLABS_API_KEY') else '❌ not set'}") @@ -429,25 +436,20 @@ TTS_SCHEMA = { "parameters": { "type": "object", "properties": { - "text": { - "type": "string", - "description": "The text to convert to speech. Keep under 4000 characters." - }, + "text": {"type": "string", "description": "The text to convert to speech. Keep under 4000 characters."}, "output_path": { "type": "string", - "description": "Optional custom file path to save the audio. Defaults to ~/.hermes/audio_cache/.mp3" - } + "description": "Optional custom file path to save the audio. Defaults to ~/.hermes/audio_cache/.mp3", + }, }, - "required": ["text"] - } + "required": ["text"], + }, } registry.register( name="text_to_speech", toolset="tts", schema=TTS_SCHEMA, - handler=lambda args, **kw: text_to_speech_tool( - text=args.get("text", ""), - output_path=args.get("output_path")), + handler=lambda args, **kw: text_to_speech_tool(text=args.get("text", ""), output_path=args.get("output_path")), check_fn=check_tts_requirements, ) diff --git a/tools/vision_tools.py b/tools/vision_tools.py index d91051175a..f06df2b15d 100644 --- a/tools/vision_tools.py +++ b/tools/vision_tools.py @@ -19,7 +19,7 @@ Features: Usage: from vision_tools import vision_analyze_tool import asyncio - + # Analyze an image result = await vision_analyze_tool( image_url="https://example.com/image.jpg", @@ -33,11 +33,14 @@ import json import logging import os import uuid +from collections.abc import Awaitable from pathlib import Path -from typing import Any, Awaitable, Dict, Optional +from typing import Any from urllib.parse import urlparse + import httpx from openai import AsyncOpenAI + from agent.auxiliary_client import get_vision_auxiliary_client from tools.debug_helpers import DebugSession @@ -55,7 +58,7 @@ if _aux_sync_client is not None: _async_kwargs["default_headers"] = { "HTTP-Referer": "https://github.com/NousResearch/hermes-agent", "X-OpenRouter-Title": "Hermes Agent", - "X-OpenRouter-Categories": "productivity,cli-agent", + "X-OpenRouter-Categories": "productivity,cli-agent", } _aux_async_client = AsyncOpenAI(**_async_kwargs) @@ -65,10 +68,10 @@ _debug = DebugSession("vision_tools", env_var="VISION_TOOLS_DEBUG") def _validate_image_url(url: str) -> bool: """ Basic validation of image URL format. - + Args: url (str): The URL to validate - + Returns: bool: True if URL appears to be valid, False otherwise """ @@ -91,23 +94,22 @@ def _validate_image_url(url: str) -> bool: async def _download_image(image_url: str, destination: Path, max_retries: int = 3) -> Path: """ Download an image from a URL to a local destination (async) with retry logic. - + Args: image_url (str): The URL of the image to download destination (Path): The path where the image should be saved max_retries (int): Maximum number of retry attempts (default: 3) - + Returns: Path: The path to the downloaded image - + Raises: Exception: If download fails after all retries """ - import asyncio - + # Create parent directories if they don't exist destination.parent.mkdir(parents=True, exist_ok=True) - + last_error = None for attempt in range(max_retries): try: @@ -122,10 +124,10 @@ async def _download_image(image_url: str, destination: Path, max_retries: int = }, ) response.raise_for_status() - + # Save the image content destination.write_bytes(response.content) - + return destination except Exception as e: last_error = e @@ -141,56 +143,56 @@ async def _download_image(image_url: str, destination: Path, max_retries: int = str(e)[:100], exc_info=True, ) - + raise last_error def _determine_mime_type(image_path: Path) -> str: """ Determine the MIME type of an image based on its file extension. - + Args: image_path (Path): Path to the image file - + Returns: str: The MIME type (defaults to image/jpeg if unknown) """ extension = image_path.suffix.lower() mime_types = { - '.jpg': 'image/jpeg', - '.jpeg': 'image/jpeg', - '.png': 'image/png', - '.gif': 'image/gif', - '.bmp': 'image/bmp', - '.webp': 'image/webp', - '.svg': 'image/svg+xml' + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".png": "image/png", + ".gif": "image/gif", + ".bmp": "image/bmp", + ".webp": "image/webp", + ".svg": "image/svg+xml", } - return mime_types.get(extension, 'image/jpeg') + return mime_types.get(extension, "image/jpeg") -def _image_to_base64_data_url(image_path: Path, mime_type: Optional[str] = None) -> str: +def _image_to_base64_data_url(image_path: Path, mime_type: str | None = None) -> str: """ Convert an image file to a base64-encoded data URL. - + Args: image_path (Path): Path to the image file mime_type (Optional[str]): MIME type of the image (auto-detected if None) - + Returns: str: Base64-encoded data URL (e.g., "data:image/jpeg;base64,...") """ # Read the image as bytes data = image_path.read_bytes() - + # Encode to base64 encoded = base64.b64encode(data).decode("ascii") - + # Determine MIME type mime = mime_type or _determine_mime_type(image_path) - + # Create data URL data_url = f"data:{mime};base64,{encoded}" - + return data_url @@ -201,31 +203,31 @@ async def vision_analyze_tool( ) -> str: """ Analyze an image from a URL or local file path using vision AI. - + This tool accepts either an HTTP/HTTPS URL or a local file path. For URLs, it downloads the image first. In both cases, the image is converted to base64 and processed using Gemini 3 Flash Preview via OpenRouter API. - + The user_prompt parameter is expected to be pre-formatted by the calling function (typically model_tools.py) to include both full description requests and specific questions. - + Args: image_url (str): The URL or local file path of the image to analyze. Accepts http://, https:// URLs or absolute/relative file paths. user_prompt (str): The pre-formatted prompt for the vision model model (str): The vision model to use (default: google/gemini-3-flash-preview) - + Returns: str: JSON string containing the analysis results with the following structure: { "success": bool, "analysis": str (defaults to error message if None) } - + Raises: Exception: If download fails, analysis fails, or API key is not set - + Note: - For URLs, temporary images are stored in ./temp_vision_images/ and cleaned up - For local file paths, the file is used directly and NOT deleted @@ -235,36 +237,41 @@ async def vision_analyze_tool( "parameters": { "image_url": image_url, "user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt, - "model": model + "model": model, }, "error": None, "success": False, "analysis_length": 0, "model_used": model, - "image_size_bytes": 0 + "image_size_bytes": 0, } - + temp_image_path = None # Track whether we should clean up the file after processing. # Local files (e.g. from the image cache) should NOT be deleted. should_cleanup = True - + try: from tools.interrupt import is_interrupted + if is_interrupted(): return json.dumps({"success": False, "error": "Interrupted"}) logger.info("Analyzing image: %s", image_url[:60]) logger.info("User prompt: %s", user_prompt[:100]) - + # Check auxiliary vision client availability if _aux_async_client is None or DEFAULT_VISION_MODEL is None: - return json.dumps({ - "success": False, - "analysis": "Vision analysis unavailable: no auxiliary vision model configured. " - "Set OPENROUTER_API_KEY or configure Nous Portal to enable vision tools." - }, indent=2, ensure_ascii=False) - + return json.dumps( + { + "success": False, + "analysis": "Vision analysis unavailable: no auxiliary vision model configured. " + "Set OPENROUTER_API_KEY or configure Nous Portal to enable vision tools.", + }, + indent=2, + ensure_ascii=False, + ) + # Determine if this is a local file path or a remote URL local_path = Path(image_url) if local_path.is_file(): @@ -280,50 +287,41 @@ async def vision_analyze_tool( await _download_image(image_url, temp_image_path) should_cleanup = True else: - raise ValueError( - "Invalid image source. Provide an HTTP/HTTPS URL or a valid local file path." - ) - + raise ValueError("Invalid image source. Provide an HTTP/HTTPS URL or a valid local file path.") + # Get image file size for logging image_size_bytes = temp_image_path.stat().st_size image_size_kb = image_size_bytes / 1024 logger.info("Image ready (%.1f KB)", image_size_kb) - + # Convert image to base64 data URL logger.info("Converting image to base64...") image_data_url = _image_to_base64_data_url(temp_image_path) # Calculate size in KB for better readability data_size_kb = len(image_data_url) / 1024 logger.info("Image converted to base64 (%.1f KB)", data_size_kb) - + debug_call_data["image_size_bytes"] = image_size_bytes - + # Use the prompt as provided (model_tools.py now handles full description formatting) comprehensive_prompt = user_prompt - + # Prepare the message with base64-encoded image messages = [ { "role": "user", "content": [ - { - "type": "text", - "text": comprehensive_prompt - }, - { - "type": "image_url", - "image_url": { - "url": image_data_url - } - } - ] + {"type": "text", "text": comprehensive_prompt}, + {"type": "image_url", "image_url": {"url": image_data_url}}, + ], } ] - + logger.info("Processing image with %s...", model) - + # Call the vision API - from agent.auxiliary_client import get_auxiliary_extra_body, auxiliary_max_tokens_param + from agent.auxiliary_client import auxiliary_max_tokens_param, get_auxiliary_extra_body + _extra = get_auxiliary_extra_body() response = await _aux_async_client.chat.completions.create( model=model, @@ -332,44 +330,44 @@ async def vision_analyze_tool( **auxiliary_max_tokens_param(2000), **({} if not _extra else {"extra_body": _extra}), ) - + # Extract the analysis analysis = response.choices[0].message.content.strip() analysis_length = len(analysis) - + logger.info("Image analysis completed (%s characters)", analysis_length) - + # Prepare successful response result = { "success": True, - "analysis": analysis or "There was a problem with the request and the image could not be analyzed." + "analysis": analysis or "There was a problem with the request and the image could not be analyzed.", } - + debug_call_data["success"] = True debug_call_data["analysis_length"] = analysis_length - + # Log debug information _debug.log_call("vision_analyze_tool", debug_call_data) _debug.save() - + return json.dumps(result, indent=2, ensure_ascii=False) - + except Exception as e: error_msg = f"Error analyzing image: {str(e)}" logger.error("%s", error_msg, exc_info=True) - + # Prepare error response result = { "success": False, - "analysis": "There was a problem with the request and the image could not be analyzed." + "analysis": "There was a problem with the request and the image could not be analyzed.", } - + debug_call_data["error"] = error_msg _debug.log_call("vision_analyze_tool", debug_call_data) _debug.save() - + return json.dumps(result, indent=2, ensure_ascii=False) - + finally: # Clean up temporary image file (but NOT local/cached files) if should_cleanup and temp_image_path and temp_image_path.exists(): @@ -377,9 +375,7 @@ async def vision_analyze_tool( temp_image_path.unlink() logger.debug("Cleaned up temporary image file") except Exception as cleanup_error: - logger.warning( - "Could not delete temporary file: %s", cleanup_error, exc_info=True - ) + logger.warning("Could not delete temporary file: %s", cleanup_error, exc_info=True) def check_vision_requirements() -> bool: @@ -387,10 +383,10 @@ def check_vision_requirements() -> bool: return _aux_async_client is not None -def get_debug_session_info() -> Dict[str, Any]: +def get_debug_session_info() -> dict[str, Any]: """ Get information about the current debug session. - + Returns: Dict[str, Any]: Dictionary containing debug session information """ @@ -403,27 +399,27 @@ if __name__ == "__main__": """ print("👁️ Vision Tools Module") print("=" * 40) - + # Check if vision model is available api_available = check_vision_requirements() - + if not api_available: print("❌ No auxiliary vision model available") print("Set OPENROUTER_API_KEY or configure Nous Portal to enable vision tools.") exit(1) else: print(f"✅ Vision model available: {DEFAULT_VISION_MODEL}") - + print("🛠️ Vision tools ready for use!") print(f"🧠 Using model: {DEFAULT_VISION_MODEL}") - + # Show debug mode status if _debug.active: print(f"🐛 Debug mode ENABLED - Session ID: {_debug.session_id}") print(f" Debug logs will be saved to: ./logs/vision_tools_debug_{_debug.session_id}.json") else: print("🐛 Debug mode disabled (set VISION_TOOLS_DEBUG=true to enable)") - + print("\nBasic usage:") print(" from vision_tools import vision_analyze_tool") print(" import asyncio") @@ -435,14 +431,14 @@ if __name__ == "__main__": print(" )") print(" print(result)") print(" asyncio.run(main())") - + print("\nExample prompts:") print(" - 'What architectural style is this building?'") print(" - 'Describe the emotions and mood in this image'") print(" - 'What text can you read in this image?'") print(" - 'Identify any safety hazards visible'") print(" - 'What products or brands are shown?'") - + print("\nDebug mode:") print(" # Enable debug logging") print(" export VISION_TOOLS_DEBUG=true") @@ -461,30 +457,24 @@ VISION_ANALYZE_SCHEMA = { "parameters": { "type": "object", "properties": { - "image_url": { - "type": "string", - "description": "Image URL (http/https) or local file path to analyze." - }, + "image_url": {"type": "string", "description": "Image URL (http/https) or local file path to analyze."}, "question": { "type": "string", - "description": "Your specific question or request about the image to resolve. The AI will automatically provide a complete image description AND answer your specific question." - } + "description": "Your specific question or request about the image to resolve. The AI will automatically provide a complete image description AND answer your specific question.", + }, }, - "required": ["image_url", "question"] - } + "required": ["image_url", "question"], + }, } -def _handle_vision_analyze(args: Dict[str, Any], **kw: Any) -> Awaitable[str]: +def _handle_vision_analyze(args: dict[str, Any], **kw: Any) -> Awaitable[str]: image_url = args.get("image_url", "") question = args.get("question", "") full_prompt = ( - "Fully describe and explain everything about this image, then answer the " - f"following question:\n\n{question}" + f"Fully describe and explain everything about this image, then answer the following question:\n\n{question}" ) - model = (os.getenv("AUXILIARY_VISION_MODEL", "").strip() - or DEFAULT_VISION_MODEL - or "google/gemini-3-flash-preview") + model = os.getenv("AUXILIARY_VISION_MODEL", "").strip() or DEFAULT_VISION_MODEL or "google/gemini-3-flash-preview" return vision_analyze_tool(image_url, full_prompt, model) diff --git a/tools/web_tools.py b/tools/web_tools.py index e99d94fb0d..900bb7fa00 100644 --- a/tools/web_tools.py +++ b/tools/web_tools.py @@ -25,29 +25,30 @@ Debug Mode: Usage: from web_tools import web_search_tool, web_extract_tool, web_crawl_tool - + # Search the web results = web_search_tool("Python machine learning libraries", limit=3) - - # Extract content from URLs + + # Extract content from URLs content = web_extract_tool(["https://example.com"], format="markdown") - + # Crawl a website crawl_data = web_crawl_tool("example.com", "Find contact information") """ -#TODO: Search Capabilities over the scraped pages -#TODO: Store the pages in something -#TODO: Tool to see what pages are available/saved to search over +# TODO: Search Capabilities over the scraped pages +# TODO: Store the pages in something +# TODO: Tool to see what pages are available/saved to search over +import asyncio import json import logging import os import re -import asyncio -from typing import List, Dict, Any, Optional +from typing import Any + from firecrawl import Firecrawl -from openai import AsyncOpenAI + from agent.auxiliary_client import get_async_text_auxiliary_client from tools.debug_helpers import DebugSession @@ -55,6 +56,7 @@ logger = logging.getLogger(__name__) _firecrawl_client = None + def _get_firecrawl_client(): """Get or create the Firecrawl client (lazy initialization). @@ -81,6 +83,7 @@ def _get_firecrawl_client(): _firecrawl_client = Firecrawl(**kwargs) return _firecrawl_client + DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION = 5000 # Resolve async auxiliary client at module level. @@ -88,61 +91,58 @@ DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION = 5000 _aux_async_client, _DEFAULT_SUMMARIZER_MODEL = get_async_text_auxiliary_client("web_extract") # Allow per-task override via config.yaml auxiliary.web_extract_model -DEFAULT_SUMMARIZER_MODEL = ( - os.getenv("AUXILIARY_WEB_EXTRACT_MODEL", "").strip() - or _DEFAULT_SUMMARIZER_MODEL -) +DEFAULT_SUMMARIZER_MODEL = os.getenv("AUXILIARY_WEB_EXTRACT_MODEL", "").strip() or _DEFAULT_SUMMARIZER_MODEL _debug = DebugSession("web_tools", env_var="WEB_TOOLS_DEBUG") async def process_content_with_llm( - content: str, - url: str = "", + content: str, + url: str = "", title: str = "", model: str = DEFAULT_SUMMARIZER_MODEL, - min_length: int = DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION -) -> Optional[str]: + min_length: int = DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION, +) -> str | None: """ Process web content using LLM to create intelligent summaries with key excerpts. - - This function uses Gemini 3 Flash Preview (or specified model) via OpenRouter API + + This function uses Gemini 3 Flash Preview (or specified model) via OpenRouter API to intelligently extract key information and create markdown summaries, significantly reducing token usage while preserving all important information. - + For very large content (>500k chars), uses chunked processing with synthesis. For extremely large content (>2M chars), refuses to process entirely. - + Args: content (str): The raw content to process url (str): The source URL (for context, optional) title (str): The page title (for context, optional) model (str): The model to use for processing (default: google/gemini-3-flash-preview) min_length (int): Minimum content length to trigger processing (default: 5000) - + Returns: Optional[str]: Processed markdown content, or None if content too short or processing fails """ # Size thresholds MAX_CONTENT_SIZE = 2_000_000 # 2M chars - refuse entirely above this - CHUNK_THRESHOLD = 500_000 # 500k chars - use chunked processing above this - CHUNK_SIZE = 100_000 # 100k chars per chunk - MAX_OUTPUT_SIZE = 5000 # Hard cap on final output size - + CHUNK_THRESHOLD = 500_000 # 500k chars - use chunked processing above this + CHUNK_SIZE = 100_000 # 100k chars per chunk + MAX_OUTPUT_SIZE = 5000 # Hard cap on final output size + try: content_len = len(content) - + # Refuse if content is absurdly large if content_len > MAX_CONTENT_SIZE: size_mb = content_len / 1_000_000 logger.warning("Content too large (%.1fMB > 2MB limit). Refusing to process.", size_mb) return f"[Content too large to process: {size_mb:.1f}MB. Try using web_crawl with specific extraction instructions, or search for a more focused source.]" - + # Skip processing if content is too short if content_len < min_length: logger.debug("Content too short (%d < %d chars), skipping LLM processing", content_len, min_length) return None - + # Create context information context_info = [] if title: @@ -150,47 +150,44 @@ async def process_content_with_llm( if url: context_info.append(f"Source: {url}") context_str = "\n".join(context_info) + "\n\n" if context_info else "" - + # Check if we need chunked processing if content_len > CHUNK_THRESHOLD: logger.info("Content large (%d chars). Using chunked processing...", content_len) - return await _process_large_content_chunked( - content, context_str, model, CHUNK_SIZE, MAX_OUTPUT_SIZE - ) - + return await _process_large_content_chunked(content, context_str, model, CHUNK_SIZE, MAX_OUTPUT_SIZE) + # Standard single-pass processing for normal content logger.info("Processing content with LLM (%d characters)", content_len) - + processed_content = await _call_summarizer_llm(content, context_str, model) - + if processed_content: # Enforce output cap if len(processed_content) > MAX_OUTPUT_SIZE: - processed_content = processed_content[:MAX_OUTPUT_SIZE] + "\n\n[... summary truncated for context management ...]" - + processed_content = ( + processed_content[:MAX_OUTPUT_SIZE] + "\n\n[... summary truncated for context management ...]" + ) + # Log compression metrics processed_length = len(processed_content) compression_ratio = processed_length / content_len if content_len > 0 else 1.0 - logger.info("Content processed: %d -> %d chars (%.1f%%)", content_len, processed_length, compression_ratio * 100) - + logger.info( + "Content processed: %d -> %d chars (%.1f%%)", content_len, processed_length, compression_ratio * 100 + ) + return processed_content - + except Exception as e: logger.debug("Error processing content with LLM: %s", e) return f"[Failed to process content: {str(e)[:100]}. Content size: {len(content):,} chars]" async def _call_summarizer_llm( - content: str, - context_str: str, - model: str, - max_tokens: int = 20000, - is_chunk: bool = False, - chunk_info: str = "" -) -> Optional[str]: + content: str, context_str: str, model: str, max_tokens: int = 20000, is_chunk: bool = False, chunk_info: str = "" +) -> str | None: """ Make a single LLM call to summarize content. - + Args: content: The content to summarize context_str: Context information (title, URL) @@ -198,7 +195,7 @@ async def _call_summarizer_llm( max_tokens: Maximum output tokens is_chunk: Whether this is a chunk of a larger document chunk_info: Information about chunk position (e.g., "Chunk 2/5") - + Returns: Summarized content or None on failure """ @@ -252,14 +249,12 @@ Create a markdown summary that captures all key information in a well-organized, if _aux_async_client is None: logger.warning("No auxiliary model available for web content processing") return None - from agent.auxiliary_client import get_auxiliary_extra_body, auxiliary_max_tokens_param + from agent.auxiliary_client import auxiliary_max_tokens_param, get_auxiliary_extra_body + _extra = get_auxiliary_extra_body() response = await _aux_async_client.chat.completions.create( model=model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} - ], + messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}], temperature=0.1, **auxiliary_max_tokens_param(max_tokens), **({} if not _extra else {"extra_body": _extra}), @@ -268,94 +263,93 @@ Create a markdown summary that captures all key information in a well-organized, except Exception as api_error: last_error = api_error if attempt < max_retries - 1: - logger.warning("LLM API call failed (attempt %d/%d): %s", attempt + 1, max_retries, str(api_error)[:100]) + logger.warning( + "LLM API call failed (attempt %d/%d): %s", attempt + 1, max_retries, str(api_error)[:100] + ) logger.warning("Retrying in %ds...", retry_delay) await asyncio.sleep(retry_delay) retry_delay = min(retry_delay * 2, 60) else: raise last_error - + return None async def _process_large_content_chunked( - content: str, - context_str: str, - model: str, - chunk_size: int, - max_output_size: int -) -> Optional[str]: + content: str, context_str: str, model: str, chunk_size: int, max_output_size: int +) -> str | None: """ Process large content by chunking, summarizing each chunk in parallel, then synthesizing the summaries. - + Args: content: The large content to process context_str: Context information model: Model to use chunk_size: Size of each chunk in characters max_output_size: Maximum final output size - + Returns: Synthesized summary or None on failure """ # Split content into chunks chunks = [] for i in range(0, len(content), chunk_size): - chunk = content[i:i + chunk_size] + chunk = content[i : i + chunk_size] chunks.append(chunk) - + logger.info("Split into %d chunks of ~%d chars each", len(chunks), chunk_size) - + # Summarize each chunk in parallel - async def summarize_chunk(chunk_idx: int, chunk_content: str) -> tuple[int, Optional[str]]: + async def summarize_chunk(chunk_idx: int, chunk_content: str) -> tuple[int, str | None]: """Summarize a single chunk.""" try: chunk_info = f"[Processing chunk {chunk_idx + 1} of {len(chunks)}]" summary = await _call_summarizer_llm( - chunk_content, - context_str, - model, - max_tokens=10000, - is_chunk=True, - chunk_info=chunk_info + chunk_content, context_str, model, max_tokens=10000, is_chunk=True, chunk_info=chunk_info ) if summary: - logger.info("Chunk %d/%d summarized: %d -> %d chars", chunk_idx + 1, len(chunks), len(chunk_content), len(summary)) + logger.info( + "Chunk %d/%d summarized: %d -> %d chars", + chunk_idx + 1, + len(chunks), + len(chunk_content), + len(summary), + ) return chunk_idx, summary except Exception as e: logger.warning("Chunk %d/%d failed: %s", chunk_idx + 1, len(chunks), str(e)[:50]) return chunk_idx, None - + # Run all chunk summarizations in parallel tasks = [summarize_chunk(i, chunk) for i, chunk in enumerate(chunks)] results = await asyncio.gather(*tasks) - + # Collect successful summaries in order summaries = [] for chunk_idx, summary in sorted(results, key=lambda x: x[0]): if summary: summaries.append(f"## Section {chunk_idx + 1}\n{summary}") - + if not summaries: logger.debug("All chunk summarizations failed") return "[Failed to process large content: all chunk summarizations failed]" - + logger.info("Got %d/%d chunk summaries", len(summaries), len(chunks)) - + # If only one chunk succeeded, just return it (with cap) if len(summaries) == 1: result = summaries[0] if len(result) > max_output_size: result = result[:max_output_size] + "\n\n[... truncated ...]" return result - + # Synthesize the summaries into a final summary logger.info("Synthesizing %d summaries...", len(summaries)) - + combined_summaries = "\n\n---\n\n".join(summaries) - - synthesis_prompt = f"""You have been given summaries of different sections of a large document. + + synthesis_prompt = f"""You have been given summaries of different sections of a large document. Synthesize these into ONE cohesive, comprehensive summary that: 1. Removes redundancy between sections 2. Preserves all key facts, figures, and actionable information @@ -375,31 +369,35 @@ Create a single, unified markdown summary.""" fallback = fallback[:max_output_size] + "\n\n[... truncated ...]" return fallback - from agent.auxiliary_client import get_auxiliary_extra_body, auxiliary_max_tokens_param + from agent.auxiliary_client import auxiliary_max_tokens_param, get_auxiliary_extra_body + _extra = get_auxiliary_extra_body() response = await _aux_async_client.chat.completions.create( model=model, messages=[ - {"role": "system", "content": "You synthesize multiple summaries into one cohesive, comprehensive summary. Be thorough but concise."}, - {"role": "user", "content": synthesis_prompt} + { + "role": "system", + "content": "You synthesize multiple summaries into one cohesive, comprehensive summary. Be thorough but concise.", + }, + {"role": "user", "content": synthesis_prompt}, ], temperature=0.1, **auxiliary_max_tokens_param(20000), **({} if not _extra else {"extra_body": _extra}), ) final_summary = response.choices[0].message.content.strip() - + # Enforce hard cap if len(final_summary) > max_output_size: final_summary = final_summary[:max_output_size] + "\n\n[... summary truncated for context management ...]" - + original_len = len(content) final_len = len(final_summary) compression = final_len / original_len if original_len > 0 else 1.0 - + logger.info("Synthesis complete: %d -> %d chars (%.2f%%)", original_len, final_len, compression * 100) return final_summary - + except Exception as e: logger.warning("Synthesis failed: %s", str(e)[:100]) # Fall back to concatenated summaries with truncation @@ -412,50 +410,50 @@ Create a single, unified markdown summary.""" def clean_base64_images(text: str) -> str: """ Remove base64 encoded images from text to reduce token count and clutter. - + This function finds and removes base64 encoded images in various formats: - (data:image/png;base64,...) - (data:image/jpeg;base64,...) - (data:image/svg+xml;base64,...) - data:image/[type];base64,... (without parentheses) - + Args: text: The text content to clean - + Returns: Cleaned text with base64 images replaced with placeholders """ # Pattern to match base64 encoded images wrapped in parentheses # Matches: (data:image/[type];base64,[base64-string]) - base64_with_parens_pattern = r'\(data:image/[^;]+;base64,[A-Za-z0-9+/=]+\)' - + base64_with_parens_pattern = r"\(data:image/[^;]+;base64,[A-Za-z0-9+/=]+\)" + # Pattern to match base64 encoded images without parentheses # Matches: data:image/[type];base64,[base64-string] - base64_pattern = r'data:image/[^;]+;base64,[A-Za-z0-9+/=]+' - + base64_pattern = r"data:image/[^;]+;base64,[A-Za-z0-9+/=]+" + # Replace parentheses-wrapped images first - cleaned_text = re.sub(base64_with_parens_pattern, '[BASE64_IMAGE_REMOVED]', text) - + cleaned_text = re.sub(base64_with_parens_pattern, "[BASE64_IMAGE_REMOVED]", text) + # Then replace any remaining non-parentheses images - cleaned_text = re.sub(base64_pattern, '[BASE64_IMAGE_REMOVED]', cleaned_text) - + cleaned_text = re.sub(base64_pattern, "[BASE64_IMAGE_REMOVED]", cleaned_text) + return cleaned_text def web_search_tool(query: str, limit: int = 5) -> str: """ Search the web for information using available search API backend. - + This function provides a generic interface for web search that can work with multiple backends. Currently uses Firecrawl. - + Note: This function returns search result metadata only (URLs, titles, descriptions). Use web_extract_tool to get full content from specific URLs. - + Args: query (str): The search query to look up limit (int): Maximum number of results to return (default: 5) - + Returns: str: JSON string containing search results with the following structure: { @@ -472,122 +470,112 @@ def web_search_tool(query: str, limit: int = 5) -> str: ] } } - + Raises: Exception: If search fails or API key is not set """ debug_call_data = { - "parameters": { - "query": query, - "limit": limit - }, + "parameters": {"query": query, "limit": limit}, "error": None, "results_count": 0, "original_response_size": 0, - "final_response_size": 0 + "final_response_size": 0, } - + try: from tools.interrupt import is_interrupted + if is_interrupted(): return json.dumps({"error": "Interrupted", "success": False}) logger.info("Searching the web for: '%s' (limit: %d)", query, limit) - - response = _get_firecrawl_client().search( - query=query, - limit=limit - ) - + + response = _get_firecrawl_client().search(query=query, limit=limit) + # The response is a SearchData object with web, news, and images attributes # When not scraping, the results are directly in these attributes web_results = [] - + # Check if response has web attribute (SearchData object) - if hasattr(response, 'web'): + if hasattr(response, "web"): # Response is a SearchData object with web attribute if response.web: # Convert each SearchResultWeb object to dict for result in response.web: - if hasattr(result, 'model_dump'): + if hasattr(result, "model_dump"): # Pydantic model - use model_dump web_results.append(result.model_dump()) - elif hasattr(result, '__dict__'): + elif hasattr(result, "__dict__"): # Regular object - use __dict__ web_results.append(result.__dict__) elif isinstance(result, dict): # Already a dict web_results.append(result) - elif hasattr(response, 'model_dump'): + elif hasattr(response, "model_dump"): # Response has model_dump method - use it to get dict response_dict = response.model_dump() - if 'web' in response_dict and response_dict['web']: - web_results = response_dict['web'] + if "web" in response_dict and response_dict["web"]: + web_results = response_dict["web"] elif isinstance(response, dict): # Response is already a dictionary - if 'web' in response and response['web']: - web_results = response['web'] - + if "web" in response and response["web"]: + web_results = response["web"] + results_count = len(web_results) logger.info("Found %d search results", results_count) - + # Build response with just search metadata (URLs, titles, descriptions) - response_data = { - "success": True, - "data": { - "web": web_results - } - } - + response_data = {"success": True, "data": {"web": web_results}} + # Capture debug information debug_call_data["results_count"] = results_count - + # Convert to JSON result_json = json.dumps(response_data, indent=2, ensure_ascii=False) - + debug_call_data["final_response_size"] = len(result_json) - + # Log debug information _debug.log_call("web_search_tool", debug_call_data) _debug.save() - + return result_json - + except Exception as e: error_msg = f"Error searching web: {str(e)}" logger.debug("%s", error_msg) - + debug_call_data["error"] = error_msg _debug.log_call("web_search_tool", debug_call_data) _debug.save() - + return json.dumps({"error": error_msg}, ensure_ascii=False) async def web_extract_tool( - urls: List[str], - format: str = None, + urls: list[str], + format: str = None, use_llm_processing: bool = True, model: str = DEFAULT_SUMMARIZER_MODEL, - min_length: int = DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION + min_length: int = DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION, ) -> str: """ Extract content from specific web pages using available extraction API backend. - + This function provides a generic interface for web content extraction that can work with multiple backends. Currently uses Firecrawl. - + Args: urls (List[str]): List of URLs to extract content from format (str): Desired output format ("markdown" or "html", optional) use_llm_processing (bool): Whether to process content with LLM for summarization (default: True) model (str): The model to use for LLM processing (default: google/gemini-3-flash-preview) min_length (int): Minimum content length to trigger LLM processing (default: 5000) - + Returns: str: JSON string containing extracted content. If LLM processing is enabled and successful, the 'content' field will contain the processed markdown summary instead of raw content. - + Raises: Exception: If extraction fails or API key is not set """ @@ -597,7 +585,7 @@ async def web_extract_tool( "format": format, "use_llm_processing": use_llm_processing, "model": model, - "min_length": min_length + "min_length": min_length, }, "error": None, "pages_extracted": 0, @@ -605,14 +593,14 @@ async def web_extract_tool( "original_response_size": 0, "final_response_size": 0, "compression_metrics": [], - "processing_applied": [] + "processing_applied": [], } - + try: logger.info("Extracting content from %d URL(s)", len(urls)) - + # Determine requested formats for Firecrawl v2 - formats: List[str] = [] + formats: list[str] = [] if format == "markdown": formats = ["markdown"] elif format == "html": @@ -620,12 +608,13 @@ async def web_extract_tool( else: # Default: request markdown for LLM-readiness and include html as backup formats = ["markdown", "html"] - + # Always use individual scraping for simplicity and reliability # Batch scraping adds complexity without much benefit for small numbers of URLs - results: List[Dict[str, Any]] = [] - + results: list[dict[str, Any]] = [] + from tools.interrupt import is_interrupted as _is_interrupted + for url in urls: if _is_interrupted(): results.append({"url": url, "error": "Interrupted", "title": ""}) @@ -633,34 +622,31 @@ async def web_extract_tool( try: logger.info("Scraping: %s", url) - scrape_result = _get_firecrawl_client().scrape( - url=url, - formats=formats - ) - + scrape_result = _get_firecrawl_client().scrape(url=url, formats=formats) + # Process the result - properly handle object serialization metadata = {} title = "" content_markdown = None content_html = None - + # Extract data from the scrape result - if hasattr(scrape_result, 'model_dump'): + if hasattr(scrape_result, "model_dump"): # Pydantic model - use model_dump to get dict result_dict = scrape_result.model_dump() - content_markdown = result_dict.get('markdown') - content_html = result_dict.get('html') - metadata = result_dict.get('metadata', {}) - elif hasattr(scrape_result, '__dict__'): + content_markdown = result_dict.get("markdown") + content_html = result_dict.get("html") + metadata = result_dict.get("metadata", {}) + elif hasattr(scrape_result, "__dict__"): # Regular object with attributes - content_markdown = getattr(scrape_result, 'markdown', None) - content_html = getattr(scrape_result, 'html', None) - + content_markdown = getattr(scrape_result, "markdown", None) + content_html = getattr(scrape_result, "html", None) + # Handle metadata - convert to dict if it's an object - metadata_obj = getattr(scrape_result, 'metadata', {}) - if hasattr(metadata_obj, 'model_dump'): + metadata_obj = getattr(scrape_result, "metadata", {}) + if hasattr(metadata_obj, "model_dump"): metadata = metadata_obj.model_dump() - elif hasattr(metadata_obj, '__dict__'): + elif hasattr(metadata_obj, "__dict__"): metadata = metadata_obj.__dict__ elif isinstance(metadata_obj, dict): metadata = metadata_obj @@ -668,87 +654,85 @@ async def web_extract_tool( metadata = {} elif isinstance(scrape_result, dict): # Already a dictionary - content_markdown = scrape_result.get('markdown') - content_html = scrape_result.get('html') - metadata = scrape_result.get('metadata', {}) - + content_markdown = scrape_result.get("markdown") + content_html = scrape_result.get("html") + metadata = scrape_result.get("metadata", {}) + # Ensure metadata is a dict (not an object) if not isinstance(metadata, dict): - if hasattr(metadata, 'model_dump'): + if hasattr(metadata, "model_dump"): metadata = metadata.model_dump() - elif hasattr(metadata, '__dict__'): + elif hasattr(metadata, "__dict__"): metadata = metadata.__dict__ else: metadata = {} - + # Get title from metadata title = metadata.get("title", "") - + # Choose content based on requested format - chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or "" - - results.append({ - "url": metadata.get("sourceURL", url), - "title": title, - "content": chosen_content, - "raw_content": chosen_content, - "metadata": metadata # Now guaranteed to be a dict - }) - + chosen_content = ( + content_markdown + if (format == "markdown" or (format is None and content_markdown)) + else content_html or content_markdown or "" + ) + + results.append( + { + "url": metadata.get("sourceURL", url), + "title": title, + "content": chosen_content, + "raw_content": chosen_content, + "metadata": metadata, # Now guaranteed to be a dict + } + ) + except Exception as scrape_err: logger.debug("Scrape failed for %s: %s", url, scrape_err) - results.append({ - "url": url, - "title": "", - "content": "", - "raw_content": "", - "error": str(scrape_err) - }) + results.append({"url": url, "title": "", "content": "", "raw_content": "", "error": str(scrape_err)}) response = {"results": results} - - pages_extracted = len(response.get('results', [])) + + pages_extracted = len(response.get("results", [])) logger.info("Extracted content from %d pages", pages_extracted) - + debug_call_data["pages_extracted"] = pages_extracted debug_call_data["original_response_size"] = len(json.dumps(response)) - + # Process each result with LLM if enabled and auxiliary client is available if use_llm_processing and _aux_async_client is not None: logger.info("Processing extracted content with LLM (parallel)...") debug_call_data["processing_applied"].append("llm_processing") - + # Prepare tasks for parallel processing async def process_single_result(result): """Process a single result with LLM and return updated result with metrics.""" - url = result.get('url', 'Unknown URL') - title = result.get('title', '') - raw_content = result.get('raw_content', '') or result.get('content', '') - + url = result.get("url", "Unknown URL") + title = result.get("title", "") + raw_content = result.get("raw_content", "") or result.get("content", "") + if not raw_content: return result, None, "no_content" - + original_size = len(raw_content) - + # Process content with LLM - processed = await process_content_with_llm( - raw_content, url, title, model, min_length - ) - + processed = await process_content_with_llm(raw_content, url, title, model, min_length) + if processed: processed_size = len(processed) compression_ratio = processed_size / original_size if original_size > 0 else 1.0 - + # Update result with processed content - result['content'] = processed - result['raw_content'] = raw_content - + result["content"] = processed + result["raw_content"] = raw_content + metrics = { "url": url, "original_size": original_size, "processed_size": processed_size, "compression_ratio": compression_ratio, - "model_used": model + "model_used": model, } return result, metrics, "processed" else: @@ -758,18 +742,18 @@ async def web_extract_tool( "processed_size": original_size, "compression_ratio": 1.0, "model_used": None, - "reason": "content_too_short" + "reason": "content_too_short", } return result, metrics, "too_short" - + # Run all LLM processing in parallel - results_list = response.get('results', []) + results_list = response.get("results", []) tasks = [process_single_result(result) for result in results_list] processed_results = await asyncio.gather(*tasks) - + # Collect metrics and print results for result, metrics, status in processed_results: - url = result.get('url', 'Unknown URL') + url = result.get("url", "Unknown URL") if status == "processed": debug_call_data["compression_metrics"].append(metrics) debug_call_data["pages_processed_with_llm"] += 1 @@ -783,13 +767,13 @@ async def web_extract_tool( if use_llm_processing and _aux_async_client is None: logger.warning("LLM processing requested but no auxiliary model available, returning raw content") debug_call_data["processing_applied"].append("llm_processing_unavailable") - + # Print summary of extracted pages for debugging (original behavior) - for result in response.get('results', []): - url = result.get('url', 'Unknown URL') - content_length = len(result.get('raw_content', '')) + for result in response.get("results", []): + url = result.get("url", "Unknown URL") + content_length = len(result.get("raw_content", "")) logger.info("%s (%d characters)", url, content_length) - + # Trim output to minimal fields per entry: title, content, error trimmed_results = [ { @@ -806,46 +790,46 @@ async def web_extract_tool( result_json = json.dumps({"error": "Content was inaccessible or not found"}, ensure_ascii=False) cleaned_result = clean_base64_images(result_json) - + else: result_json = json.dumps(trimmed_response, indent=2, ensure_ascii=False) - + cleaned_result = clean_base64_images(result_json) - + debug_call_data["final_response_size"] = len(cleaned_result) debug_call_data["processing_applied"].append("base64_image_removal") - + # Log debug information _debug.log_call("web_extract_tool", debug_call_data) _debug.save() - + return cleaned_result - + except Exception as e: error_msg = f"Error extracting content: {str(e)}" logger.debug("%s", error_msg) - + debug_call_data["error"] = error_msg _debug.log_call("web_extract_tool", debug_call_data) _debug.save() - + return json.dumps({"error": error_msg}, ensure_ascii=False) async def web_crawl_tool( - url: str, - instructions: str = None, - depth: str = "basic", + url: str, + instructions: str = None, + depth: str = "basic", use_llm_processing: bool = True, model: str = DEFAULT_SUMMARIZER_MODEL, - min_length: int = DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION + min_length: int = DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION, ) -> str: """ Crawl a website with specific instructions using available crawling API backend. - + This function provides a generic interface for web crawling that can work with multiple backends. Currently uses Firecrawl. - + Args: url (str): The base URL to crawl (can include or exclude https://) instructions (str): Instructions for what to crawl/extract using LLM intelligence (optional) @@ -853,12 +837,12 @@ async def web_crawl_tool( use_llm_processing (bool): Whether to process content with LLM for summarization (default: True) model (str): The model to use for LLM processing (default: google/gemini-3-flash-preview) min_length (int): Minimum content length to trigger LLM processing (default: 5000) - + Returns: str: JSON string containing crawled content. If LLM processing is enabled and successful, the 'content' field will contain the processed markdown summary instead of raw content. Each page is processed individually. - + Raises: Exception: If crawling fails or API key is not set """ @@ -869,7 +853,7 @@ async def web_crawl_tool( "depth": depth, "use_llm_processing": use_llm_processing, "model": model, - "min_length": min_length + "min_length": min_length, }, "error": None, "pages_crawled": 0, @@ -877,74 +861,74 @@ async def web_crawl_tool( "original_response_size": 0, "final_response_size": 0, "compression_metrics": [], - "processing_applied": [] + "processing_applied": [], } - + try: # Ensure URL has protocol - if not url.startswith(('http://', 'https://')): - url = f'https://{url}' + if not url.startswith(("http://", "https://")): + url = f"https://{url}" logger.info("Added https:// prefix to URL: %s", url) - + instructions_text = f" with instructions: '{instructions}'" if instructions else "" logger.info("Crawling %s%s", url, instructions_text) - + # Use Firecrawl's v2 crawl functionality # Docs: https://docs.firecrawl.dev/features/crawl # The crawl() method automatically waits for completion and returns all data - + # Build crawl parameters - keep it simple crawl_params = { "limit": 20, # Limit number of pages to crawl "scrape_options": { "formats": ["markdown"] # Just markdown for simplicity - } + }, } - + # Note: The 'prompt' parameter is not documented for crawl # Instructions are typically used with the Extract endpoint, not Crawl if instructions: logger.info("Instructions parameter ignored (not supported in crawl API)") - + from tools.interrupt import is_interrupted as _is_int + if _is_int(): return json.dumps({"error": "Interrupted", "success": False}) try: - crawl_result = _get_firecrawl_client().crawl( - url=url, - **crawl_params - ) + crawl_result = _get_firecrawl_client().crawl(url=url, **crawl_params) except Exception as e: logger.debug("Crawl API call failed: %s", e) raise - pages: List[Dict[str, Any]] = [] - + pages: list[dict[str, Any]] = [] + # Process crawl results - the crawl method returns a CrawlJob object with data attribute data_list = [] - + # The crawl_result is a CrawlJob object with a 'data' attribute containing list of Document objects - if hasattr(crawl_result, 'data'): + if hasattr(crawl_result, "data"): data_list = crawl_result.data if crawl_result.data else [] - logger.info("Status: %s", getattr(crawl_result, 'status', 'unknown')) + logger.info("Status: %s", getattr(crawl_result, "status", "unknown")) logger.info("Retrieved %d pages", len(data_list)) - + # Debug: Check other attributes if no data if not data_list: - logger.debug("CrawlJob attributes: %s", [attr for attr in dir(crawl_result) if not attr.startswith('_')]) - logger.debug("Status: %s", getattr(crawl_result, 'status', 'N/A')) - logger.debug("Total: %s", getattr(crawl_result, 'total', 'N/A')) - logger.debug("Completed: %s", getattr(crawl_result, 'completed', 'N/A')) - - elif isinstance(crawl_result, dict) and 'data' in crawl_result: + logger.debug( + "CrawlJob attributes: %s", [attr for attr in dir(crawl_result) if not attr.startswith("_")] + ) + logger.debug("Status: %s", getattr(crawl_result, "status", "N/A")) + logger.debug("Total: %s", getattr(crawl_result, "total", "N/A")) + logger.debug("Completed: %s", getattr(crawl_result, "completed", "N/A")) + + elif isinstance(crawl_result, dict) and "data" in crawl_result: data_list = crawl_result.get("data", []) else: logger.warning("Unexpected crawl result type") logger.debug("Result type: %s", type(crawl_result)) - if hasattr(crawl_result, '__dict__'): + if hasattr(crawl_result, "__dict__"): logger.debug("Result attributes: %s", list(crawl_result.__dict__.keys())) - + for item in data_list: # Process each crawled page - properly handle object serialization page_url = "Unknown URL" @@ -952,24 +936,24 @@ async def web_crawl_tool( content_markdown = None content_html = None metadata = {} - + # Extract data from the item - if hasattr(item, 'model_dump'): + if hasattr(item, "model_dump"): # Pydantic model - use model_dump to get dict item_dict = item.model_dump() - content_markdown = item_dict.get('markdown') - content_html = item_dict.get('html') - metadata = item_dict.get('metadata', {}) - elif hasattr(item, '__dict__'): + content_markdown = item_dict.get("markdown") + content_html = item_dict.get("html") + metadata = item_dict.get("metadata", {}) + elif hasattr(item, "__dict__"): # Regular object with attributes - content_markdown = getattr(item, 'markdown', None) - content_html = getattr(item, 'html', None) - + content_markdown = getattr(item, "markdown", None) + content_html = getattr(item, "html", None) + # Handle metadata - convert to dict if it's an object - metadata_obj = getattr(item, 'metadata', {}) - if hasattr(metadata_obj, 'model_dump'): + metadata_obj = getattr(item, "metadata", {}) + if hasattr(metadata_obj, "model_dump"): metadata = metadata_obj.model_dump() - elif hasattr(metadata_obj, '__dict__'): + elif hasattr(metadata_obj, "__dict__"): metadata = metadata_obj.__dict__ elif isinstance(metadata_obj, dict): metadata = metadata_obj @@ -977,78 +961,78 @@ async def web_crawl_tool( metadata = {} elif isinstance(item, dict): # Already a dictionary - content_markdown = item.get('markdown') - content_html = item.get('html') - metadata = item.get('metadata', {}) - + content_markdown = item.get("markdown") + content_html = item.get("html") + metadata = item.get("metadata", {}) + # Ensure metadata is a dict (not an object) if not isinstance(metadata, dict): - if hasattr(metadata, 'model_dump'): + if hasattr(metadata, "model_dump"): metadata = metadata.model_dump() - elif hasattr(metadata, '__dict__'): + elif hasattr(metadata, "__dict__"): metadata = metadata.__dict__ else: metadata = {} - + # Extract URL and title from metadata page_url = metadata.get("sourceURL", metadata.get("url", "Unknown URL")) title = metadata.get("title", "") - + # Choose content (prefer markdown) content = content_markdown or content_html or "" - - pages.append({ - "url": page_url, - "title": title, - "content": content, - "raw_content": content, - "metadata": metadata # Now guaranteed to be a dict - }) + + pages.append( + { + "url": page_url, + "title": title, + "content": content, + "raw_content": content, + "metadata": metadata, # Now guaranteed to be a dict + } + ) response = {"results": pages} - - pages_crawled = len(response.get('results', [])) + + pages_crawled = len(response.get("results", [])) logger.info("Crawled %d pages", pages_crawled) - + debug_call_data["pages_crawled"] = pages_crawled debug_call_data["original_response_size"] = len(json.dumps(response)) - + # Process each result with LLM if enabled and auxiliary client is available if use_llm_processing and _aux_async_client is not None: logger.info("Processing crawled content with LLM (parallel)...") debug_call_data["processing_applied"].append("llm_processing") - + # Prepare tasks for parallel processing async def process_single_crawl_result(result): """Process a single crawl result with LLM and return updated result with metrics.""" - page_url = result.get('url', 'Unknown URL') - title = result.get('title', '') - content = result.get('content', '') - + page_url = result.get("url", "Unknown URL") + title = result.get("title", "") + content = result.get("content", "") + if not content: return result, None, "no_content" - + original_size = len(content) - + # Process content with LLM - processed = await process_content_with_llm( - content, page_url, title, model, min_length - ) - + processed = await process_content_with_llm(content, page_url, title, model, min_length) + if processed: processed_size = len(processed) compression_ratio = processed_size / original_size if original_size > 0 else 1.0 - + # Update result with processed content - result['raw_content'] = content - result['content'] = processed - + result["raw_content"] = content + result["content"] = processed + metrics = { "url": page_url, "original_size": original_size, "processed_size": processed_size, "compression_ratio": compression_ratio, - "model_used": model + "model_used": model, } return result, metrics, "processed" else: @@ -1058,18 +1042,18 @@ async def web_crawl_tool( "processed_size": original_size, "compression_ratio": 1.0, "model_used": None, - "reason": "content_too_short" + "reason": "content_too_short", } return result, metrics, "too_short" - + # Run all LLM processing in parallel - results_list = response.get('results', []) + results_list = response.get("results", []) tasks = [process_single_crawl_result(result) for result in results_list] processed_results = await asyncio.gather(*tasks) - + # Collect metrics and print results for result, metrics, status in processed_results: - page_url = result.get('url', 'Unknown URL') + page_url = result.get("url", "Unknown URL") if status == "processed": debug_call_data["compression_metrics"].append(metrics) debug_call_data["pages_processed_with_llm"] += 1 @@ -1083,45 +1067,41 @@ async def web_crawl_tool( if use_llm_processing and _aux_async_client is None: logger.warning("LLM processing requested but no auxiliary model available, returning raw content") debug_call_data["processing_applied"].append("llm_processing_unavailable") - + # Print summary of crawled pages for debugging (original behavior) - for result in response.get('results', []): - page_url = result.get('url', 'Unknown URL') - content_length = len(result.get('content', '')) + for result in response.get("results", []): + page_url = result.get("url", "Unknown URL") + content_length = len(result.get("content", "")) logger.info("%s (%d characters)", page_url, content_length) - + # Trim output to minimal fields per entry: title, content, error trimmed_results = [ - { - "title": r.get("title", ""), - "content": r.get("content", ""), - "error": r.get("error") - } + {"title": r.get("title", ""), "content": r.get("content", ""), "error": r.get("error")} for r in response.get("results", []) ] trimmed_response = {"results": trimmed_results} - + result_json = json.dumps(trimmed_response, indent=2, ensure_ascii=False) # Clean base64 images from crawled content cleaned_result = clean_base64_images(result_json) - + debug_call_data["final_response_size"] = len(cleaned_result) debug_call_data["processing_applied"].append("base64_image_removal") - + # Log debug information _debug.log_call("web_crawl_tool", debug_call_data) _debug.save() - + return cleaned_result - + except Exception as e: error_msg = f"Error crawling website: {str(e)}" logger.debug("%s", error_msg) - + debug_call_data["error"] = error_msg _debug.log_call("web_crawl_tool", debug_call_data) _debug.save() - + return json.dumps({"error": error_msg}, ensure_ascii=False) @@ -1129,7 +1109,7 @@ async def web_crawl_tool( def check_firecrawl_api_key() -> bool: """ Check if the Firecrawl API key is available in environment variables. - + Returns: bool: True if API key is set, False otherwise """ @@ -1141,7 +1121,7 @@ def check_auxiliary_model() -> bool: return _aux_async_client is not None -def get_debug_session_info() -> Dict[str, Any]: +def get_debug_session_info() -> dict[str, Any]: """Get information about the current debug session.""" return _debug.get_session_info() @@ -1152,41 +1132,41 @@ if __name__ == "__main__": """ print("🌐 Standalone Web Tools Module") print("=" * 40) - + # Check if API keys are available firecrawl_available = check_firecrawl_api_key() nous_available = check_auxiliary_model() - + if not firecrawl_available: print("❌ FIRECRAWL_API_KEY environment variable not set") print("Please set your API key: export FIRECRAWL_API_KEY='your-key-here'") print("Get API key at: https://firecrawl.dev/") else: print("✅ Firecrawl API key found") - + if not nous_available: print("❌ No auxiliary model available for LLM content processing") print("Set OPENROUTER_API_KEY, configure Nous Portal, or set OPENAI_BASE_URL + OPENAI_API_KEY") print("⚠️ Without an auxiliary model, LLM content processing will be disabled") else: print(f"✅ Auxiliary model available: {DEFAULT_SUMMARIZER_MODEL}") - + if not firecrawl_available: exit(1) - + print("🛠️ Web tools ready for use!") - + if nous_available: print(f"🧠 LLM content processing available with {DEFAULT_SUMMARIZER_MODEL}") print(f" Default min length for processing: {DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION} chars") - + # Show debug mode status if _debug.active: print(f"🐛 Debug mode ENABLED - Session ID: {_debug.session_id}") print(f" Debug logs will be saved to: {_debug.log_dir}/web_tools_debug_{_debug.session_id}.json") else: print("🐛 Debug mode disabled (set WEB_TOOLS_DEBUG=true to enable)") - + print("\nBasic usage:") print(" from web_tools import web_search_tool, web_extract_tool, web_crawl_tool") print(" import asyncio") @@ -1199,7 +1179,7 @@ if __name__ == "__main__": print(" content = await web_extract_tool(['https://example.com'])") print(" crawl_data = await web_crawl_tool('example.com', 'Find docs')") print(" asyncio.run(main())") - + if nous_available: print("\nLLM-enhanced usage:") print(" # Content automatically processed for pages >5000 chars (default)") @@ -1215,7 +1195,7 @@ if __name__ == "__main__": print("") print(" # Disable LLM processing") print(" raw_content = await web_extract_tool(['https://example.com'], use_llm_processing=False)") - + print("\nDebug mode:") print(" # Enable debug logging") print(" export WEB_TOOLS_DEBUG=true") @@ -1225,8 +1205,8 @@ if __name__ == "__main__": print(" # - LLM compression metrics") print(" # - Final processed results") print(" # Logs saved to: ./logs/web_tools_debug_UUID.json") - - print(f"\n📝 Run 'python test_web_tools_llm.py' to test LLM processing capabilities") + + print("\n📝 Run 'python test_web_tools_llm.py' to test LLM processing capabilities") # --------------------------------------------------------------------------- @@ -1239,14 +1219,9 @@ WEB_SEARCH_SCHEMA = { "description": "Search the web for information on any topic. Returns up to 5 relevant results with titles, URLs, and descriptions.", "parameters": { "type": "object", - "properties": { - "query": { - "type": "string", - "description": "The search query to look up on the web" - } - }, - "required": ["query"] - } + "properties": {"query": {"type": "string", "description": "The search query to look up on the web"}}, + "required": ["query"], + }, } WEB_EXTRACT_SCHEMA = { @@ -1259,11 +1234,11 @@ WEB_EXTRACT_SCHEMA = { "type": "array", "items": {"type": "string"}, "description": "List of URLs to extract content from (max 5 URLs per call)", - "maxItems": 5 + "maxItems": 5, } }, - "required": ["urls"] - } + "required": ["urls"], + }, } registry.register( @@ -1279,7 +1254,8 @@ registry.register( toolset="web", schema=WEB_EXTRACT_SCHEMA, handler=lambda args, **kw: web_extract_tool( - args.get("urls", [])[:5] if isinstance(args.get("urls"), list) else [], "markdown"), + args.get("urls", [])[:5] if isinstance(args.get("urls"), list) else [], "markdown" + ), check_fn=check_firecrawl_api_key, requires_env=["FIRECRAWL_API_KEY"], is_async=True, diff --git a/toolsets.py b/toolsets.py index 87b48c7ecb..c5fc2b13b7 100644 --- a/toolsets.py +++ b/toolsets.py @@ -15,55 +15,75 @@ Features: Usage: from toolsets import get_toolset, resolve_toolset, get_all_toolsets - + # Get tools for a specific toolset tools = get_toolset("research") - + # Resolve a toolset to get all tool names (including from composed toolsets) all_tools = resolve_toolset("full_stack") """ -from typing import List, Dict, Any, Set, Optional - +from typing import Any # Shared tool list for CLI and all messaging platform toolsets. # Edit this once to update all platforms simultaneously. _HERMES_CORE_TOOLS = [ # Web - "web_search", "web_extract", + "web_search", + "web_extract", # Terminal + process management - "terminal", "process", + "terminal", + "process", # File manipulation - "read_file", "write_file", "patch", "search_files", + "read_file", + "write_file", + "patch", + "search_files", # Vision + image generation - "vision_analyze", "image_generate", + "vision_analyze", + "image_generate", # MoA "mixture_of_agents", # Skills - "skills_list", "skill_view", "skill_manage", + "skills_list", + "skill_view", + "skill_manage", # Browser automation - "browser_navigate", "browser_snapshot", "browser_click", - "browser_type", "browser_scroll", "browser_back", - "browser_press", "browser_close", "browser_get_images", + "browser_navigate", + "browser_snapshot", + "browser_click", + "browser_type", + "browser_scroll", + "browser_back", + "browser_press", + "browser_close", + "browser_get_images", "browser_vision", # Text-to-speech "text_to_speech", # Planning & memory - "todo", "memory", + "todo", + "memory", # Session history search "session_search", # Clarifying questions "clarify", # Code execution + delegation - "execute_code", "delegate_task", + "execute_code", + "delegate_task", # Cronjob management - "schedule_cronjob", "list_cronjobs", "remove_cronjob", + "schedule_cronjob", + "list_cronjobs", + "remove_cronjob", # Cross-platform messaging (gated on gateway running via check_fn) "send_message", # Honcho user context (gated on honcho being active via check_fn) "query_user_context", # Home Assistant smart home control (gated on HASS_TOKEN via check_fn) - "ha_list_entities", "ha_get_state", "ha_list_services", "ha_call_service", + "ha_list_entities", + "ha_get_state", + "ha_list_services", + "ha_call_service", ] @@ -74,149 +94,125 @@ TOOLSETS = { "web": { "description": "Web research and content extraction tools", "tools": ["web_search", "web_extract"], - "includes": [] # No other toolsets included + "includes": [], # No other toolsets included }, - "search": { "description": "Web search only (no content extraction/scraping)", "tools": ["web_search"], - "includes": [] + "includes": [], }, - - "vision": { - "description": "Image analysis and vision tools", - "tools": ["vision_analyze"], - "includes": [] - }, - - "image_gen": { - "description": "Creative generation tools (images)", - "tools": ["image_generate"], - "includes": [] - }, - + "vision": {"description": "Image analysis and vision tools", "tools": ["vision_analyze"], "includes": []}, + "image_gen": {"description": "Creative generation tools (images)", "tools": ["image_generate"], "includes": []}, "terminal": { "description": "Terminal/command execution and process management tools", "tools": ["terminal", "process"], - "includes": [] + "includes": [], }, - "moa": { "description": "Advanced reasoning and problem-solving tools", "tools": ["mixture_of_agents"], - "includes": [] + "includes": [], }, - "skills": { "description": "Access, create, edit, and manage skill documents with specialized instructions and knowledge", "tools": ["skills_list", "skill_view", "skill_manage"], - "includes": [] + "includes": [], }, - "browser": { "description": "Browser automation for web interaction (navigate, click, type, scroll, iframes, hold-click) with web search for finding URLs", "tools": [ - "browser_navigate", "browser_snapshot", "browser_click", - "browser_type", "browser_scroll", "browser_back", - "browser_press", "browser_close", "browser_get_images", - "browser_vision", "web_search" + "browser_navigate", + "browser_snapshot", + "browser_click", + "browser_type", + "browser_scroll", + "browser_back", + "browser_press", + "browser_close", + "browser_get_images", + "browser_vision", + "web_search", ], - "includes": [] + "includes": [], }, - "cronjob": { "description": "Cronjob management tools - schedule, list, and remove automated tasks", "tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"], - "includes": [] + "includes": [], }, - "rl": { "description": "RL training tools for running reinforcement learning on Tinker-Atropos", "tools": [ - "rl_list_environments", "rl_select_environment", - "rl_get_current_config", "rl_edit_config", - "rl_start_training", "rl_check_status", - "rl_stop_training", "rl_get_results", - "rl_list_runs", "rl_test_inference" + "rl_list_environments", + "rl_select_environment", + "rl_get_current_config", + "rl_edit_config", + "rl_start_training", + "rl_check_status", + "rl_stop_training", + "rl_get_results", + "rl_list_runs", + "rl_test_inference", ], - "includes": [] + "includes": [], }, - "file": { "description": "File manipulation tools: read, write, patch (with fuzzy matching), and search (content + files)", "tools": ["read_file", "write_file", "patch", "search_files"], - "includes": [] + "includes": [], }, - "tts": { "description": "Text-to-speech: convert text to audio with Edge TTS (free), ElevenLabs, or OpenAI", "tools": ["text_to_speech"], - "includes": [] + "includes": [], }, - - "todo": { - "description": "Task planning and tracking for multi-step work", - "tools": ["todo"], - "includes": [] - }, - + "todo": {"description": "Task planning and tracking for multi-step work", "tools": ["todo"], "includes": []}, "memory": { "description": "Persistent memory across sessions (personal notes + user profile)", "tools": ["memory"], - "includes": [] + "includes": [], }, - "session_search": { "description": "Search and recall past conversations with summarization", "tools": ["session_search"], - "includes": [] + "includes": [], }, - "clarify": { "description": "Ask the user clarifying questions (multiple-choice or open-ended)", "tools": ["clarify"], - "includes": [] + "includes": [], }, - "code_execution": { "description": "Run Python scripts that call tools programmatically (reduces LLM round trips)", "tools": ["execute_code"], - "includes": [] + "includes": [], }, - "delegation": { "description": "Spawn subagents with isolated context for complex subtasks", "tools": ["delegate_task"], - "includes": [] + "includes": [], }, - "honcho": { "description": "Honcho AI-native memory for persistent cross-session user modeling", "tools": ["query_user_context"], - "includes": [] + "includes": [], }, - "homeassistant": { "description": "Home Assistant smart home control and monitoring", "tools": ["ha_list_entities", "ha_get_state", "ha_list_services", "ha_call_service"], - "includes": [] + "includes": [], }, - - # Scenario-specific toolsets - "debugging": { "description": "Debugging and troubleshooting toolkit", "tools": ["terminal", "process"], - "includes": ["web", "file"] # For searching error messages and solutions, and file operations + "includes": ["web", "file"], # For searching error messages and solutions, and file operations }, - "safe": { "description": "Safe toolkit without terminal access", "tools": ["mixture_of_agents"], - "includes": ["web", "vision", "image_gen"] + "includes": ["web", "vision", "image_gen"], }, - # ========================================================================== # Full Hermes toolsets (CLI + messaging platforms) # @@ -224,65 +220,63 @@ TOOLSETS = { # All platforms share the same core tools (including send_message, # which is gated on gateway running via its check_fn). # ========================================================================== - "hermes-cli": { "description": "Full interactive CLI toolset - all default tools plus cronjob management", "tools": _HERMES_CORE_TOOLS, - "includes": [] + "includes": [], }, - "hermes-telegram": { "description": "Telegram bot toolset - full access for personal use (terminal has safety checks)", "tools": _HERMES_CORE_TOOLS, - "includes": [] + "includes": [], }, - "hermes-discord": { "description": "Discord bot toolset - full access (terminal has safety checks via dangerous command approval)", "tools": _HERMES_CORE_TOOLS, - "includes": [] + "includes": [], }, - "hermes-whatsapp": { "description": "WhatsApp bot toolset - similar to Telegram (personal messaging, more trusted)", "tools": _HERMES_CORE_TOOLS, - "includes": [] + "includes": [], }, - "hermes-slack": { "description": "Slack bot toolset - full access for workspace use (terminal has safety checks)", "tools": _HERMES_CORE_TOOLS, - "includes": [] + "includes": [], }, - "hermes-signal": { "description": "Signal bot toolset - encrypted messaging platform (full access)", "tools": _HERMES_CORE_TOOLS, - "includes": [] + "includes": [], }, - "hermes-homeassistant": { "description": "Home Assistant bot toolset - smart home event monitoring and control", "tools": _HERMES_CORE_TOOLS, - "includes": [] + "includes": [], }, - "hermes-gateway": { "description": "Gateway toolset - union of all messaging platform tools", "tools": [], - "includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-signal", "hermes-homeassistant"] - } + "includes": [ + "hermes-telegram", + "hermes-discord", + "hermes-whatsapp", + "hermes-slack", + "hermes-signal", + "hermes-homeassistant", + ], + }, } - -def get_toolset(name: str) -> Optional[Dict[str, Any]]: +def get_toolset(name: str) -> dict[str, Any] | None: """ Get a toolset definition by name. - + Args: name (str): Name of the toolset - + Returns: Dict: Toolset definition with description, tools, and includes None: If toolset not found @@ -291,27 +285,27 @@ def get_toolset(name: str) -> Optional[Dict[str, Any]]: return TOOLSETS.get(name) -def resolve_toolset(name: str, visited: Set[str] = None) -> List[str]: +def resolve_toolset(name: str, visited: set[str] = None) -> list[str]: """ Recursively resolve a toolset to get all tool names. - + This function handles toolset composition by recursively resolving included toolsets and combining all tools. - + Args: name (str): Name of the toolset to resolve visited (Set[str]): Set of already visited toolsets (for cycle detection) - + Returns: List[str]: List of all tool names in the toolset """ if visited is None: visited = set() - + # Special aliases that represent all tools across every toolset # This ensures future toolsets are automatically included without changes. if name in {"all", "*"}: - all_tools: Set[str] = set() + all_tools: set[str] = set() for toolset_name in get_toolset_names(): # Use a fresh visited set per branch to avoid cross-branch contamination resolved = resolve_toolset(toolset_name, visited.copy()) @@ -322,73 +316,71 @@ def resolve_toolset(name: str, visited: Set[str] = None) -> List[str]: if name in visited: print(f"⚠️ Circular dependency detected in toolset '{name}'") return [] - + visited.add(name) - + # Get toolset definition toolset = TOOLSETS.get(name) if not toolset: return [] - + # Collect direct tools tools = set(toolset.get("tools", [])) - + # Recursively resolve included toolsets for included_name in toolset.get("includes", []): included_tools = resolve_toolset(included_name, visited.copy()) tools.update(included_tools) - + return list(tools) -def resolve_multiple_toolsets(toolset_names: List[str]) -> List[str]: +def resolve_multiple_toolsets(toolset_names: list[str]) -> list[str]: """ Resolve multiple toolsets and combine their tools. - + Args: toolset_names (List[str]): List of toolset names to resolve - + Returns: List[str]: Combined list of all tool names (deduplicated) """ all_tools = set() - + for name in toolset_names: tools = resolve_toolset(name) all_tools.update(tools) - + return list(all_tools) -def get_all_toolsets() -> Dict[str, Dict[str, Any]]: +def get_all_toolsets() -> dict[str, dict[str, Any]]: """ Get all available toolsets with their definitions. - + Returns: Dict: All toolset definitions """ return TOOLSETS.copy() -def get_toolset_names() -> List[str]: +def get_toolset_names() -> list[str]: """ Get names of all available toolsets (excluding aliases). - + Returns: List[str]: List of toolset names """ return list(TOOLSETS.keys()) - - def validate_toolset(name: str) -> bool: """ Check if a toolset name is valid. - + Args: name (str): Toolset name to validate - + Returns: bool: True if valid, False otherwise """ @@ -398,46 +390,35 @@ def validate_toolset(name: str) -> bool: return name in TOOLSETS -def create_custom_toolset( - name: str, - description: str, - tools: List[str] = None, - includes: List[str] = None -) -> None: +def create_custom_toolset(name: str, description: str, tools: list[str] = None, includes: list[str] = None) -> None: """ Create a custom toolset at runtime. - + Args: name (str): Name for the new toolset description (str): Description of the toolset tools (List[str]): Direct tools to include includes (List[str]): Other toolsets to include """ - TOOLSETS[name] = { - "description": description, - "tools": tools or [], - "includes": includes or [] - } + TOOLSETS[name] = {"description": description, "tools": tools or [], "includes": includes or []} - - -def get_toolset_info(name: str) -> Dict[str, Any]: +def get_toolset_info(name: str) -> dict[str, Any]: """ Get detailed information about a toolset including resolved tools. - + Args: name (str): Toolset name - + Returns: Dict: Detailed toolset information """ toolset = get_toolset(name) if not toolset: return None - + resolved_tools = resolve_toolset(name) - + return { "name": name, "description": toolset["description"], @@ -445,32 +426,32 @@ def get_toolset_info(name: str) -> Dict[str, Any]: "includes": toolset["includes"], "resolved_tools": resolved_tools, "tool_count": len(resolved_tools), - "is_composite": len(toolset["includes"]) > 0 + "is_composite": len(toolset["includes"]) > 0, } def print_toolset_tree(name: str, indent: int = 0) -> None: """ Print a tree view of a toolset and its composition. - + Args: name (str): Toolset name indent (int): Current indentation level """ prefix = " " * indent toolset = get_toolset(name) - + if not toolset: print(f"{prefix}❌ Unknown toolset: {name}") return - + # Print toolset name and description print(f"{prefix}📦 {name}: {toolset['description']}") - + # Print direct tools if toolset["tools"]: print(f"{prefix} 🔧 Tools: {', '.join(toolset['tools'])}") - + # Print included toolsets if toolset["includes"]: print(f"{prefix} 📂 Includes:") @@ -481,7 +462,7 @@ def print_toolset_tree(name: str, indent: int = 0) -> None: if __name__ == "__main__": print("Toolsets System Demo") print("=" * 60) - + print("\nAvailable Toolsets:") print("-" * 40) for name, toolset in get_all_toolsets().items(): @@ -489,29 +470,29 @@ if __name__ == "__main__": composite = "[composite]" if info["is_composite"] else "[leaf]" print(f" {composite} {name:20} - {toolset['description']}") print(f" Tools: {len(info['resolved_tools'])} total") - + print("\nToolset Resolution Examples:") print("-" * 40) for name in ["web", "terminal", "safe", "debugging"]: tools = resolve_toolset(name) print(f"\n {name}:") print(f" Resolved to {len(tools)} tools: {', '.join(sorted(tools))}") - + print("\nMultiple Toolset Resolution:") print("-" * 40) combined = resolve_multiple_toolsets(["web", "vision", "terminal"]) - print(f" Combining ['web', 'vision', 'terminal']:") + print(" Combining ['web', 'vision', 'terminal']:") print(f" Result: {', '.join(sorted(combined))}") - + print("\nCustom Toolset Creation:") print("-" * 40) create_custom_toolset( name="my_custom", description="My custom toolset for specific tasks", tools=["web_search"], - includes=["terminal", "vision"] + includes=["terminal", "vision"], ) custom_info = get_toolset_info("my_custom") - print(f" Created 'my_custom' toolset:") + print(" Created 'my_custom' toolset:") print(f" Description: {custom_info['description']}") print(f" Resolved tools: {', '.join(custom_info['resolved_tools'])}") diff --git a/website/docs/developer-guide/contributing.md b/website/docs/developer-guide/contributing.md index f14ab9b400..c06a2d84b9 100644 --- a/website/docs/developer-guide/contributing.md +++ b/website/docs/developer-guide/contributing.md @@ -27,7 +27,7 @@ We value contributions in this order: | Requirement | Notes | |-------------|-------| | **Git** | With `--recurse-submodules` support | -| **Python 3.10+** | uv will install it if missing | +| **Python 3.11+** | uv will install it if missing | | **uv** | Fast Python package manager ([install](https://docs.astral.sh/uv/)) | | **Node.js 18+** | Optional — needed for browser tools and WhatsApp bridge | @@ -36,18 +36,7 @@ We value contributions in this order: ```bash git clone --recurse-submodules https://github.com/NousResearch/hermes-agent.git cd hermes-agent - -# Create venv with Python 3.11 -uv venv venv --python 3.11 -export VIRTUAL_ENV="$(pwd)/venv" - -# Install with all extras (messaging, cron, CLI menus, dev tools) -uv pip install -e ".[all,dev]" -uv pip install -e "./mini-swe-agent" -uv pip install -e "./tinker-atropos" - -# Optional: browser tools -npm install +make setup # creates .venv, installs all deps, sets up pre-commit ``` ### Configure for Development @@ -61,27 +50,21 @@ touch ~/.hermes/.env echo 'OPENROUTER_API_KEY=sk-or-v1-your-key' >> ~/.hermes/.env ``` -### Run +### Common Commands ```bash -# Symlink for global access -mkdir -p ~/.local/bin -ln -sf "$(pwd)/venv/bin/hermes" ~/.local/bin/hermes - -# Verify -hermes doctor -hermes chat -q "Hello" -``` - -### Run Tests - -```bash -pytest tests/ -v +make test # run unit tests +make lint # ruff check +make fmt # ruff format + fix +make check # lint + test (same as CI) +make dev-cli # auto-restart hermes CLI on file changes +make dev-gateway # auto-restart gateway on file changes +make test-watch # rerun tests on file changes ``` ## Code Style -- **PEP 8** with practical exceptions (no strict line length enforcement) +- **Formatting**: Enforced by **ruff** (config in `pyproject.toml`). Run `make fmt` to auto-fix, `make lint` to check. Pre-commit hooks handle this automatically. - **Comments**: Only when explaining non-obvious intent, trade-offs, or API quirks - **Error handling**: Catch specific exceptions. Use `logger.warning()`/`logger.error()` with `exc_info=True` for unexpected errors - **Cross-platform**: Never assume Unix (see below) @@ -169,7 +152,7 @@ refactor/description # Code restructuring ### Before Submitting -1. **Run tests**: `pytest tests/ -v` +1. **Run checks**: `make check` (lint + test — same as CI) 2. **Test manually**: Run `hermes` and exercise the code path you changed 3. **Check cross-platform impact**: Consider macOS and different Linux distros 4. **Keep PRs focused**: One logical change per PR