feat: devex help, add Makefile, ruff, pre-commit, and modernize CI

This commit is contained in:
Brooklyn Nicholson 2026-03-09 20:36:51 -05:00
parent 172a38c344
commit f4d7e6a29e
111 changed files with 11655 additions and 10200 deletions

View file

@ -16,249 +16,222 @@ for the AI agent to access all capabilities.
"""
# Export all tools for easy importing
from .web_tools import (
web_search_tool,
web_extract_tool,
web_crawl_tool,
check_firecrawl_api_key
)
# Primary terminal tool (mini-swe-agent backend: local/docker/singularity/modal/daytona)
from .terminal_tool import (
terminal_tool,
check_terminal_requirements,
cleanup_vm,
cleanup_all_environments,
get_active_environments_info,
register_task_env_overrides,
clear_task_env_overrides,
TERMINAL_TOOL_DESCRIPTION
)
from .vision_tools import (
vision_analyze_tool,
check_vision_requirements
)
from .mixture_of_agents_tool import (
mixture_of_agents_tool,
check_moa_requirements
)
from .image_generation_tool import (
image_generate_tool,
check_image_generation_requirements
)
from .skills_tool import (
skills_list,
skill_view,
check_skills_requirements,
SKILLS_TOOL_DESCRIPTION
)
from .skill_manager_tool import (
skill_manage,
check_skill_manage_requirements,
SKILL_MANAGE_SCHEMA
)
# Browser automation tools (agent-browser + Browserbase)
from .browser_tool import (
browser_navigate,
browser_snapshot,
browser_click,
browser_type,
browser_scroll,
BROWSER_TOOL_SCHEMAS,
browser_back,
browser_press,
browser_click,
browser_close,
browser_get_images,
browser_navigate,
browser_press,
browser_scroll,
browser_snapshot,
browser_type,
browser_vision,
cleanup_browser,
cleanup_all_browsers,
get_active_browser_sessions,
check_browser_requirements,
BROWSER_TOOL_SCHEMAS
)
# Cronjob management tools (CLI-only, hermes-cli toolset)
from .cronjob_tools import (
schedule_cronjob,
list_cronjobs,
remove_cronjob,
check_cronjob_requirements,
get_cronjob_tool_definitions,
SCHEDULE_CRONJOB_SCHEMA,
LIST_CRONJOBS_SCHEMA,
REMOVE_CRONJOB_SCHEMA
)
# RL Training tools (Tinker-Atropos)
from .rl_training_tool import (
rl_list_environments,
rl_select_environment,
rl_get_current_config,
rl_edit_config,
rl_start_training,
rl_check_status,
rl_stop_training,
rl_get_results,
rl_list_runs,
rl_test_inference,
check_rl_api_keys,
get_missing_keys,
)
# File manipulation tools (read, write, patch, search)
from .file_tools import (
read_file_tool,
write_file_tool,
patch_tool,
search_tool,
get_file_tools,
clear_file_ops_cache,
)
# Text-to-speech tools (Edge TTS / ElevenLabs / OpenAI)
from .tts_tool import (
text_to_speech_tool,
check_tts_requirements,
)
# Planning & task management tool
from .todo_tool import (
todo_tool,
check_todo_requirements,
TODO_SCHEMA,
TodoStore,
cleanup_all_browsers,
cleanup_browser,
get_active_browser_sessions,
)
# Clarifying questions tool (interactive Q&A with the user)
from .clarify_tool import (
clarify_tool,
check_clarify_requirements,
CLARIFY_SCHEMA,
check_clarify_requirements,
clarify_tool,
)
# Code execution sandbox (programmatic tool calling)
from .code_execution_tool import (
execute_code,
check_sandbox_requirements,
EXECUTE_CODE_SCHEMA,
check_sandbox_requirements,
execute_code,
)
# Cronjob management tools (CLI-only, hermes-cli toolset)
from .cronjob_tools import (
LIST_CRONJOBS_SCHEMA,
REMOVE_CRONJOB_SCHEMA,
SCHEDULE_CRONJOB_SCHEMA,
check_cronjob_requirements,
get_cronjob_tool_definitions,
list_cronjobs,
remove_cronjob,
schedule_cronjob,
)
# Subagent delegation (spawn child agents with isolated context)
from .delegate_tool import (
delegate_task,
check_delegate_requirements,
DELEGATE_TASK_SCHEMA,
check_delegate_requirements,
delegate_task,
)
# File manipulation tools (read, write, patch, search)
from .file_tools import (
clear_file_ops_cache,
get_file_tools,
patch_tool,
read_file_tool,
search_tool,
write_file_tool,
)
from .image_generation_tool import check_image_generation_requirements, image_generate_tool
from .mixture_of_agents_tool import check_moa_requirements, mixture_of_agents_tool
# RL Training tools (Tinker-Atropos)
from .rl_training_tool import (
check_rl_api_keys,
get_missing_keys,
rl_check_status,
rl_edit_config,
rl_get_current_config,
rl_get_results,
rl_list_environments,
rl_list_runs,
rl_select_environment,
rl_start_training,
rl_stop_training,
rl_test_inference,
)
from .skill_manager_tool import SKILL_MANAGE_SCHEMA, check_skill_manage_requirements, skill_manage
from .skills_tool import SKILLS_TOOL_DESCRIPTION, check_skills_requirements, skill_view, skills_list
# Primary terminal tool (mini-swe-agent backend: local/docker/singularity/modal/daytona)
from .terminal_tool import (
TERMINAL_TOOL_DESCRIPTION,
check_terminal_requirements,
cleanup_all_environments,
cleanup_vm,
clear_task_env_overrides,
get_active_environments_info,
register_task_env_overrides,
terminal_tool,
)
# Planning & task management tool
from .todo_tool import (
TODO_SCHEMA,
TodoStore,
check_todo_requirements,
todo_tool,
)
# Text-to-speech tools (Edge TTS / ElevenLabs / OpenAI)
from .tts_tool import (
check_tts_requirements,
text_to_speech_tool,
)
from .vision_tools import check_vision_requirements, vision_analyze_tool
from .web_tools import check_firecrawl_api_key, web_crawl_tool, web_extract_tool, web_search_tool
# File tools have no external requirements - they use the terminal backend
def check_file_requirements():
"""File tools only require terminal backend to be available."""
from .terminal_tool import check_terminal_requirements
return check_terminal_requirements()
__all__ = [
# Web tools
'web_search_tool',
'web_extract_tool',
'web_crawl_tool',
'check_firecrawl_api_key',
"web_search_tool",
"web_extract_tool",
"web_crawl_tool",
"check_firecrawl_api_key",
# Terminal tools (mini-swe-agent backend)
'terminal_tool',
'check_terminal_requirements',
'cleanup_vm',
'cleanup_all_environments',
'get_active_environments_info',
'register_task_env_overrides',
'clear_task_env_overrides',
'TERMINAL_TOOL_DESCRIPTION',
"terminal_tool",
"check_terminal_requirements",
"cleanup_vm",
"cleanup_all_environments",
"get_active_environments_info",
"register_task_env_overrides",
"clear_task_env_overrides",
"TERMINAL_TOOL_DESCRIPTION",
# Vision tools
'vision_analyze_tool',
'check_vision_requirements',
"vision_analyze_tool",
"check_vision_requirements",
# MoA tools
'mixture_of_agents_tool',
'check_moa_requirements',
"mixture_of_agents_tool",
"check_moa_requirements",
# Image generation tools
'image_generate_tool',
'check_image_generation_requirements',
"image_generate_tool",
"check_image_generation_requirements",
# Skills tools
'skills_list',
'skill_view',
'check_skills_requirements',
'SKILLS_TOOL_DESCRIPTION',
"skills_list",
"skill_view",
"check_skills_requirements",
"SKILLS_TOOL_DESCRIPTION",
# Skill management
'skill_manage',
'check_skill_manage_requirements',
'SKILL_MANAGE_SCHEMA',
"skill_manage",
"check_skill_manage_requirements",
"SKILL_MANAGE_SCHEMA",
# Browser automation tools
'browser_navigate',
'browser_snapshot',
'browser_click',
'browser_type',
'browser_scroll',
'browser_back',
'browser_press',
'browser_close',
'browser_get_images',
'browser_vision',
'cleanup_browser',
'cleanup_all_browsers',
'get_active_browser_sessions',
'check_browser_requirements',
'BROWSER_TOOL_SCHEMAS',
"browser_navigate",
"browser_snapshot",
"browser_click",
"browser_type",
"browser_scroll",
"browser_back",
"browser_press",
"browser_close",
"browser_get_images",
"browser_vision",
"cleanup_browser",
"cleanup_all_browsers",
"get_active_browser_sessions",
"check_browser_requirements",
"BROWSER_TOOL_SCHEMAS",
# Cronjob management tools (CLI-only)
'schedule_cronjob',
'list_cronjobs',
'remove_cronjob',
'check_cronjob_requirements',
'get_cronjob_tool_definitions',
'SCHEDULE_CRONJOB_SCHEMA',
'LIST_CRONJOBS_SCHEMA',
'REMOVE_CRONJOB_SCHEMA',
"schedule_cronjob",
"list_cronjobs",
"remove_cronjob",
"check_cronjob_requirements",
"get_cronjob_tool_definitions",
"SCHEDULE_CRONJOB_SCHEMA",
"LIST_CRONJOBS_SCHEMA",
"REMOVE_CRONJOB_SCHEMA",
# RL Training tools
'rl_list_environments',
'rl_select_environment',
'rl_get_current_config',
'rl_edit_config',
'rl_start_training',
'rl_check_status',
'rl_stop_training',
'rl_get_results',
'rl_list_runs',
'rl_test_inference',
'check_rl_api_keys',
'get_missing_keys',
"rl_list_environments",
"rl_select_environment",
"rl_get_current_config",
"rl_edit_config",
"rl_start_training",
"rl_check_status",
"rl_stop_training",
"rl_get_results",
"rl_list_runs",
"rl_test_inference",
"check_rl_api_keys",
"get_missing_keys",
# File manipulation tools
'read_file_tool',
'write_file_tool',
'patch_tool',
'search_tool',
'get_file_tools',
'clear_file_ops_cache',
'check_file_requirements',
"read_file_tool",
"write_file_tool",
"patch_tool",
"search_tool",
"get_file_tools",
"clear_file_ops_cache",
"check_file_requirements",
# Text-to-speech tools
'text_to_speech_tool',
'check_tts_requirements',
"text_to_speech_tool",
"check_tts_requirements",
# Planning & task management tool
'todo_tool',
'check_todo_requirements',
'TODO_SCHEMA',
'TodoStore',
"todo_tool",
"check_todo_requirements",
"TODO_SCHEMA",
"TodoStore",
# Clarifying questions tool
'clarify_tool',
'check_clarify_requirements',
'CLARIFY_SCHEMA',
"clarify_tool",
"check_clarify_requirements",
"CLARIFY_SCHEMA",
# Code execution sandbox
'execute_code',
'check_sandbox_requirements',
'EXECUTE_CODE_SCHEMA',
"execute_code",
"check_sandbox_requirements",
"EXECUTE_CODE_SCHEMA",
# Subagent delegation
'delegate_task',
'check_delegate_requirements',
'DELEGATE_TASK_SCHEMA',
"delegate_task",
"check_delegate_requirements",
"DELEGATE_TASK_SCHEMA",
]

View file

@ -12,7 +12,6 @@ import os
import re
import sys
import threading
from typing import Optional
logger = logging.getLogger(__name__)
@ -21,32 +20,32 @@ logger = logging.getLogger(__name__)
# =========================================================================
DANGEROUS_PATTERNS = [
(r'\brm\s+(-[^\s]*\s+)*/', "delete in root path"),
(r'\brm\s+-[^\s]*r', "recursive delete"),
(r'\brm\s+--recursive\b', "recursive delete (long flag)"),
(r'\bchmod\s+(-[^\s]*\s+)*777\b', "world-writable permissions"),
(r'\bchmod\s+--recursive\b.*777', "recursive world-writable (long flag)"),
(r'\bchown\s+(-[^\s]*)?R\s+root', "recursive chown to root"),
(r'\bchown\s+--recursive\b.*root', "recursive chown to root (long flag)"),
(r'\bmkfs\b', "format filesystem"),
(r'\bdd\s+.*if=', "disk copy"),
(r'>\s*/dev/sd', "write to block device"),
(r'\bDROP\s+(TABLE|DATABASE)\b', "SQL DROP"),
(r'\bDELETE\s+FROM\b(?!.*\bWHERE\b)', "SQL DELETE without WHERE"),
(r'\bTRUNCATE\s+(TABLE)?\s*\w', "SQL TRUNCATE"),
(r'>\s*/etc/', "overwrite system config"),
(r'\bsystemctl\s+(stop|disable|mask)\b', "stop/disable system service"),
(r'\bkill\s+-9\s+-1\b', "kill all processes"),
(r'\bpkill\s+-9\b', "force kill processes"),
(r':()\s*{\s*:\s*\|\s*:&\s*}\s*;:', "fork bomb"),
(r'\b(bash|sh|zsh)\s+-c\s+', "shell command via -c flag"),
(r'\b(python[23]?|perl|ruby|node)\s+-[ec]\s+', "script execution via -e/-c flag"),
(r'\b(curl|wget)\b.*\|\s*(ba)?sh\b', "pipe remote content to shell"),
(r'\b(bash|sh|zsh|ksh)\s+<\s*<?\s*\(\s*(curl|wget)\b', "execute remote script via process substitution"),
(r'\btee\b.*(/etc/|/dev/sd|\.ssh/|\.hermes/\.env)', "overwrite system file via tee"),
(r'\bxargs\s+.*\brm\b', "xargs with rm"),
(r'\bfind\b.*-exec\s+(/\S*/)?rm\b', "find -exec rm"),
(r'\bfind\b.*-delete\b', "find -delete"),
(r"\brm\s+(-[^\s]*\s+)*/", "delete in root path"),
(r"\brm\s+-[^\s]*r", "recursive delete"),
(r"\brm\s+--recursive\b", "recursive delete (long flag)"),
(r"\bchmod\s+(-[^\s]*\s+)*777\b", "world-writable permissions"),
(r"\bchmod\s+--recursive\b.*777", "recursive world-writable (long flag)"),
(r"\bchown\s+(-[^\s]*)?R\s+root", "recursive chown to root"),
(r"\bchown\s+--recursive\b.*root", "recursive chown to root (long flag)"),
(r"\bmkfs\b", "format filesystem"),
(r"\bdd\s+.*if=", "disk copy"),
(r">\s*/dev/sd", "write to block device"),
(r"\bDROP\s+(TABLE|DATABASE)\b", "SQL DROP"),
(r"\bDELETE\s+FROM\b(?!.*\bWHERE\b)", "SQL DELETE without WHERE"),
(r"\bTRUNCATE\s+(TABLE)?\s*\w", "SQL TRUNCATE"),
(r">\s*/etc/", "overwrite system config"),
(r"\bsystemctl\s+(stop|disable|mask)\b", "stop/disable system service"),
(r"\bkill\s+-9\s+-1\b", "kill all processes"),
(r"\bpkill\s+-9\b", "force kill processes"),
(r":()\s*{\s*:\s*\|\s*:&\s*}\s*;:", "fork bomb"),
(r"\b(bash|sh|zsh)\s+-c\s+", "shell command via -c flag"),
(r"\b(python[23]?|perl|ruby|node)\s+-[ec]\s+", "script execution via -e/-c flag"),
(r"\b(curl|wget)\b.*\|\s*(ba)?sh\b", "pipe remote content to shell"),
(r"\b(bash|sh|zsh|ksh)\s+<\s*<?\s*\(\s*(curl|wget)\b", "execute remote script via process substitution"),
(r"\btee\b.*(/etc/|/dev/sd|\.ssh/|\.hermes/\.env)", "overwrite system file via tee"),
(r"\bxargs\s+.*\brm\b", "xargs with rm"),
(r"\bfind\b.*-exec\s+(/\S*/)?rm\b", "find -exec rm"),
(r"\bfind\b.*-delete\b", "find -delete"),
]
@ -54,6 +53,7 @@ DANGEROUS_PATTERNS = [
# Detection
# =========================================================================
def detect_dangerous_command(command: str) -> tuple:
"""Check if a command matches any dangerous patterns.
@ -63,7 +63,7 @@ def detect_dangerous_command(command: str) -> tuple:
command_lower = command.lower()
for pattern, description in DANGEROUS_PATTERNS:
if re.search(pattern, command_lower, re.IGNORECASE | re.DOTALL):
pattern_key = pattern.split(r'\b')[1] if r'\b' in pattern else pattern[:20]
pattern_key = pattern.split(r"\b")[1] if r"\b" in pattern else pattern[:20]
return (True, pattern_key, description)
return (False, None, None)
@ -84,7 +84,7 @@ def submit_pending(session_key: str, approval: dict):
_pending[session_key] = approval
def pop_pending(session_key: str) -> Optional[dict]:
def pop_pending(session_key: str) -> dict | None:
"""Retrieve and remove a pending approval for a session."""
with _lock:
return _pending.pop(session_key, None)
@ -133,6 +133,7 @@ def clear_session(session_key: str):
# Config persistence for permanent allowlist
# =========================================================================
def load_permanent_allowlist() -> set:
"""Load permanently allowed command patterns from config.
@ -141,6 +142,7 @@ def load_permanent_allowlist() -> set:
"""
try:
from hermes_cli.config import load_config
config = load_config()
patterns = set(config.get("command_allowlist", []) or [])
if patterns:
@ -154,6 +156,7 @@ def save_permanent_allowlist(patterns: set):
"""Save permanently allowed command patterns to config."""
try:
from hermes_cli.config import load_config, save_config
config = load_config()
config["command_allowlist"] = list(patterns)
save_config(config)
@ -165,9 +168,8 @@ def save_permanent_allowlist(patterns: set):
# Approval prompting + orchestration
# =========================================================================
def prompt_dangerous_approval(command: str, description: str,
timeout_seconds: int = 60,
approval_callback=None) -> str:
def prompt_dangerous_approval(command: str, description: str, timeout_seconds: int = 60, approval_callback=None) -> str:
"""Prompt the user to approve a dangerous command (CLI only).
Args:
@ -188,7 +190,7 @@ def prompt_dangerous_approval(command: str, description: str,
print(f" ⚠️ DANGEROUS COMMAND: {description}")
print(f" {command[:80]}{'...' if len(command) > 80 else ''}")
print()
print(f" [o]nce | [s]ession | [a]lways | [d]eny")
print(" [o]nce | [s]ession | [a]lways | [d]eny")
print()
sys.stdout.flush()
@ -209,13 +211,13 @@ def prompt_dangerous_approval(command: str, description: str,
return "deny"
choice = result["choice"]
if choice in ('o', 'once'):
if choice in ("o", "once"):
print(" ✓ Allowed once")
return "once"
elif choice in ('s', 'session'):
elif choice in ("s", "session"):
print(" ✓ Allowed for this session")
return "session"
elif choice in ('a', 'always'):
elif choice in ("a", "always"):
print(" ✓ Added to permanent allowlist")
return "always"
else:
@ -232,8 +234,7 @@ def prompt_dangerous_approval(command: str, description: str,
sys.stdout.flush()
def check_dangerous_command(command: str, env_type: str,
approval_callback=None) -> dict:
def check_dangerous_command(command: str, env_type: str, approval_callback=None) -> dict:
"""Check if a command is dangerous and handle approval.
This is the main entry point called by terminal_tool before executing
@ -265,11 +266,14 @@ def check_dangerous_command(command: str, env_type: str,
return {"approved": True, "message": None}
if is_gateway or os.getenv("HERMES_EXEC_ASK"):
submit_pending(session_key, {
"command": command,
"pattern_key": pattern_key,
"description": description,
})
submit_pending(
session_key,
{
"command": command,
"pattern_key": pattern_key,
"description": description,
},
)
return {
"approved": False,
"pattern_key": pattern_key,
@ -279,8 +283,7 @@ def check_dangerous_command(command: str, env_type: str,
"message": f"⚠️ This command is potentially dangerous ({description}). Asking the user for approval...",
}
choice = prompt_dangerous_approval(command, description,
approval_callback=approval_callback)
choice = prompt_dangerous_approval(command, description, approval_callback=approval_callback)
if choice == "deny":
return {

File diff suppressed because it is too large Load diff

View file

@ -12,8 +12,7 @@ a thin dispatcher that delegates to a platform-provided callback.
"""
import json
from typing import Dict, Any, List, Optional, Callable
from collections.abc import Callable
# Maximum number of predefined choices the agent can offer.
# A 5th "Other (type your answer)" option is always appended by the UI.
@ -22,8 +21,8 @@ MAX_CHOICES = 4
def clarify_tool(
question: str,
choices: Optional[List[str]] = None,
callback: Optional[Callable] = None,
choices: list[str] | None = None,
callback: Callable | None = None,
) -> str:
"""
Ask the user a question, optionally with multiple-choice options.
@ -68,11 +67,14 @@ def clarify_tool(
ensure_ascii=False,
)
return json.dumps({
"question": question,
"choices_offered": choices,
"user_response": str(user_response).strip(),
}, ensure_ascii=False)
return json.dumps(
{
"question": question,
"choices_offered": choices,
"user_response": str(user_response).strip(),
},
ensure_ascii=False,
)
def check_clarify_requirements() -> bool:
@ -133,8 +135,7 @@ registry.register(
toolset="clarify",
schema=CLARIFY_SCHEMA,
handler=lambda args, **kw: clarify_tool(
question=args.get("question", ""),
choices=args.get("choices"),
callback=kw.get("callback")),
question=args.get("question", ""), choices=args.get("choices"), callback=kw.get("callback")
),
check_fn=check_clarify_requirements,
)

View file

@ -31,7 +31,7 @@ import time
import uuid
_IS_WINDOWS = platform.system() == "Windows"
from typing import Any, Dict, List, Optional
from typing import Any
# Availability gate: UDS requires a POSIX OS
logger = logging.getLogger(__name__)
@ -40,21 +40,23 @@ SANDBOX_AVAILABLE = sys.platform != "win32"
# The 7 tools allowed inside the sandbox. The intersection of this list
# and the session's enabled tools determines which stubs are generated.
SANDBOX_ALLOWED_TOOLS = frozenset([
"web_search",
"web_extract",
"read_file",
"write_file",
"search_files",
"patch",
"terminal",
])
SANDBOX_ALLOWED_TOOLS = frozenset(
[
"web_search",
"web_extract",
"read_file",
"write_file",
"search_files",
"patch",
"terminal",
]
)
# Resource limit defaults (overridable via config.yaml → code_execution.*)
DEFAULT_TIMEOUT = 300 # 5 minutes
DEFAULT_TIMEOUT = 300 # 5 minutes
DEFAULT_MAX_TOOL_CALLS = 50
MAX_STDOUT_BYTES = 50_000 # 50 KB
MAX_STDERR_BYTES = 10_000 # 10 KB
MAX_STDOUT_BYTES = 50_000 # 50 KB
MAX_STDERR_BYTES = 10_000 # 10 KB
def check_sandbox_requirements() -> bool:
@ -114,7 +116,7 @@ _TOOL_STUBS = {
}
def generate_hermes_tools_module(enabled_tools: List[str]) -> str:
def generate_hermes_tools_module(enabled_tools: list[str]) -> str:
"""
Build the source code for the hermes_tools.py stub module.
@ -128,11 +130,7 @@ def generate_hermes_tools_module(enabled_tools: List[str]) -> str:
if tool_name not in _TOOL_STUBS:
continue
func_name, sig, doc, args_expr = _TOOL_STUBS[tool_name]
stub_functions.append(
f"def {func_name}({sig}):\n"
f" {doc}\n"
f" return _call({func_name!r}, {args_expr})\n"
)
stub_functions.append(f"def {func_name}({sig}):\n {doc}\n return _call({func_name!r}, {args_expr})\n")
export_names.append(func_name)
header = '''\
@ -223,7 +221,7 @@ def _rpc_server_loop(
server_sock: socket.socket,
task_id: str,
tool_call_log: list,
tool_call_counter: list, # mutable [int] so the thread can increment
tool_call_counter: list, # mutable [int] so the thread can increment
max_tool_calls: int,
allowed_tools: frozenset,
):
@ -243,7 +241,7 @@ def _rpc_server_loop(
while True:
try:
chunk = conn.recv(65536)
except socket.timeout:
except TimeoutError:
break
if not chunk:
break
@ -270,23 +268,22 @@ def _rpc_server_loop(
# Enforce the allow-list
if tool_name not in allowed_tools:
available = ", ".join(sorted(allowed_tools))
resp = json.dumps({
"error": (
f"Tool '{tool_name}' is not available in execute_code. "
f"Available: {available}"
)
})
resp = json.dumps(
{"error": (f"Tool '{tool_name}' is not available in execute_code. Available: {available}")}
)
conn.sendall((resp + "\n").encode())
continue
# Enforce tool call limit
if tool_call_counter[0] >= max_tool_calls:
resp = json.dumps({
"error": (
f"Tool call limit reached ({max_tool_calls}). "
"No more tool calls allowed in this execution."
)
})
resp = json.dumps(
{
"error": (
f"Tool call limit reached ({max_tool_calls}). "
"No more tool calls allowed in this execution."
)
}
)
conn.sendall((resp + "\n").encode())
continue
@ -303,9 +300,7 @@ def _rpc_server_loop(
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
try:
result = handle_function_call(
tool_name, tool_args, task_id=task_id
)
result = handle_function_call(tool_name, tool_args, task_id=task_id)
finally:
sys.stdout.close()
sys.stderr.close()
@ -318,15 +313,17 @@ def _rpc_server_loop(
# Log for observability
args_preview = str(tool_args)[:80]
tool_call_log.append({
"tool": tool_name,
"args_preview": args_preview,
"duration": round(call_duration, 2),
})
tool_call_log.append(
{
"tool": tool_name,
"args_preview": args_preview,
"duration": round(call_duration, 2),
}
)
conn.sendall((result + "\n").encode())
except socket.timeout:
except TimeoutError:
pass
except OSError:
pass
@ -342,10 +339,11 @@ def _rpc_server_loop(
# Main entry point
# ---------------------------------------------------------------------------
def execute_code(
code: str,
task_id: Optional[str] = None,
enabled_tools: Optional[List[str]] = None,
task_id: str | None = None,
enabled_tools: list[str] | None = None,
) -> str:
"""
Run a Python script in a sandboxed child process with RPC access
@ -361,9 +359,7 @@ def execute_code(
JSON string with execution results.
"""
if not SANDBOX_AVAILABLE:
return json.dumps({
"error": "execute_code is not available on Windows. Use normal tool calls instead."
})
return json.dumps({"error": "execute_code is not available on Windows. Use normal tool calls instead."})
if not code or not code.strip():
return json.dumps({"error": "No code provided."})
@ -397,9 +393,7 @@ def execute_code(
try:
# Write the auto-generated hermes_tools module
tools_src = generate_hermes_tools_module(
list(sandbox_tools) if enabled_tools else list(SANDBOX_ALLOWED_TOOLS)
)
tools_src = generate_hermes_tools_module(list(sandbox_tools) if enabled_tools else list(SANDBOX_ALLOWED_TOOLS))
with open(os.path.join(tmpdir, "hermes_tools.py"), "w") as f:
f.write(tools_src)
@ -415,8 +409,12 @@ def execute_code(
rpc_thread = threading.Thread(
target=_rpc_server_loop,
args=(
server_sock, task_id, tool_call_log,
tool_call_counter, max_tool_calls, sandbox_tools,
server_sock,
task_id,
tool_call_log,
tool_call_counter,
max_tool_calls,
sandbox_tools,
),
daemon=True,
)
@ -426,11 +424,24 @@ def execute_code(
# Build a minimal environment for the child. We intentionally exclude
# API keys and tokens to prevent credential exfiltration from LLM-
# generated scripts. The child accesses tools via RPC, not direct API.
_SAFE_ENV_PREFIXES = ("PATH", "HOME", "USER", "LANG", "LC_", "TERM",
"TMPDIR", "TMP", "TEMP", "SHELL", "LOGNAME",
"XDG_", "PYTHONPATH", "VIRTUAL_ENV", "CONDA")
_SECRET_SUBSTRINGS = ("KEY", "TOKEN", "SECRET", "PASSWORD", "CREDENTIAL",
"PASSWD", "AUTH")
_SAFE_ENV_PREFIXES = (
"PATH",
"HOME",
"USER",
"LANG",
"LC_",
"TERM",
"TMPDIR",
"TMP",
"TEMP",
"SHELL",
"LOGNAME",
"XDG_",
"PYTHONPATH",
"VIRTUAL_ENV",
"CONDA",
)
_SECRET_SUBSTRINGS = ("KEY", "TOKEN", "SECRET", "PASSWORD", "CREDENTIAL", "PASSWD", "AUTH")
child_env = {}
for k, v in os.environ.items():
if any(s in k.upper() for s in _SECRET_SUBSTRINGS):
@ -515,7 +526,7 @@ def execute_code(
rpc_thread.join(timeout=3)
# Build response
result: Dict[str, Any] = {
result: dict[str, Any] = {
"status": status,
"output": stdout_text,
"tool_calls_made": tool_call_counter[0],
@ -538,17 +549,21 @@ def execute_code(
except Exception as exc:
duration = round(time.monotonic() - exec_start, 2)
logging.exception("execute_code failed")
return json.dumps({
"status": "error",
"error": str(exc),
"tool_calls_made": tool_call_counter[0],
"duration_seconds": duration,
}, ensure_ascii=False)
return json.dumps(
{
"status": "error",
"error": str(exc),
"tool_calls_made": tool_call_counter[0],
"duration_seconds": duration,
},
ensure_ascii=False,
)
finally:
# Cleanup temp dir and socket
try:
import shutil
shutil.rmtree(tmpdir, ignore_errors=True)
except Exception as e:
logger.debug("Could not clean temp dir: %s", e)
@ -592,6 +607,7 @@ def _load_config() -> dict:
"""Load code_execution config from CLI_CONFIG if available."""
try:
from cli import CLI_CONFIG
return CLI_CONFIG.get("code_execution", {})
except Exception:
return {}
@ -604,27 +620,37 @@ def _load_config() -> dict:
# Per-tool documentation lines for the execute_code description.
# Ordered to match the canonical display order.
_TOOL_DOC_LINES = [
("web_search",
" web_search(query: str, limit: int = 5) -> dict\n"
" Returns {\"data\": {\"web\": [{\"url\", \"title\", \"description\"}, ...]}}"),
("web_extract",
" web_extract(urls: list[str]) -> dict\n"
" Returns {\"results\": [{\"url\", \"title\", \"content\", \"error\"}, ...]} where content is markdown"),
("read_file",
" read_file(path: str, offset: int = 1, limit: int = 500) -> dict\n"
" Lines are 1-indexed. Returns {\"content\": \"...\", \"total_lines\": N}"),
("write_file",
" write_file(path: str, content: str) -> dict\n"
" Always overwrites the entire file."),
("search_files",
" search_files(pattern: str, target=\"content\", path=\".\", file_glob=None, limit=50) -> dict\n"
" target: \"content\" (search inside files) or \"files\" (find files by name). Returns {\"matches\": [...]}"),
("patch",
" patch(path: str, old_string: str, new_string: str, replace_all: bool = False) -> dict\n"
" Replaces old_string with new_string in the file."),
("terminal",
" terminal(command: str, timeout=None, workdir=None) -> dict\n"
" Foreground only (no background/pty). Returns {\"output\": \"...\", \"exit_code\": N}"),
(
"web_search",
" web_search(query: str, limit: int = 5) -> dict\n"
' Returns {"data": {"web": [{"url", "title", "description"}, ...]}}',
),
(
"web_extract",
" web_extract(urls: list[str]) -> dict\n"
' Returns {"results": [{"url", "title", "content", "error"}, ...]} where content is markdown',
),
(
"read_file",
" read_file(path: str, offset: int = 1, limit: int = 500) -> dict\n"
' Lines are 1-indexed. Returns {"content": "...", "total_lines": N}',
),
("write_file", " write_file(path: str, content: str) -> dict\n Always overwrites the entire file."),
(
"search_files",
' search_files(pattern: str, target="content", path=".", file_glob=None, limit=50) -> dict\n'
' target: "content" (search inside files) or "files" (find files by name). Returns {"matches": [...]}',
),
(
"patch",
" patch(path: str, old_string: str, new_string: str, replace_all: bool = False) -> dict\n"
" Replaces old_string with new_string in the file.",
),
(
"terminal",
" terminal(command: str, timeout=None, workdir=None) -> dict\n"
' Foreground only (no background/pty). Returns {"output": "...", "exit_code": N}',
),
]
@ -639,9 +665,7 @@ def build_execute_code_schema(enabled_sandbox_tools: set = None) -> dict:
enabled_sandbox_tools = SANDBOX_ALLOWED_TOOLS
# Build tool documentation lines for only the enabled tools
tool_lines = "\n".join(
doc for name, doc in _TOOL_DOC_LINES if name in enabled_sandbox_tools
)
tool_lines = "\n".join(doc for name, doc in _TOOL_DOC_LINES if name in enabled_sandbox_tools)
# Build example import list from enabled tools
import_examples = [n for n in ("web_search", "terminal") if n in enabled_sandbox_tools]
@ -702,8 +726,7 @@ registry.register(
toolset="code_execution",
schema=EXECUTE_CODE_SCHEMA,
handler=lambda args, **kw: execute_code(
code=args.get("code", ""),
task_id=kw.get("task_id"),
enabled_tools=kw.get("enabled_tools")),
code=args.get("code", ""), task_id=kw.get("task_id"), enabled_tools=kw.get("enabled_tools")
),
check_fn=check_sandbox_requirements,
)

View file

@ -11,37 +11,44 @@ The prompt must contain ALL necessary information.
import json
import os
import re
from typing import Optional
# Import from cron module (will be available when properly installed)
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from cron.jobs import create_job, get_job, list_jobs, remove_job
# ---------------------------------------------------------------------------
# Cron prompt scanning — critical-severity patterns only, since cron prompts
# run in fresh sessions with full tool access.
# ---------------------------------------------------------------------------
_CRON_THREAT_PATTERNS = [
(r'ignore\s+(?:\w+\s+)*(?:previous|all|above|prior)\s+(?:\w+\s+)*instructions', "prompt_injection"),
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
(r'system\s+prompt\s+override', "sys_prompt_override"),
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
(r'wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_wget"),
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)', "read_secrets"),
(r'authorized_keys', "ssh_backdoor"),
(r'/etc/sudoers|visudo', "sudoers_mod"),
(r'rm\s+-rf\s+/', "destructive_root_rm"),
(r"ignore\s+(?:\w+\s+)*(?:previous|all|above|prior)\s+(?:\w+\s+)*instructions", "prompt_injection"),
(r"do\s+not\s+tell\s+the\s+user", "deception_hide"),
(r"system\s+prompt\s+override", "sys_prompt_override"),
(r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"),
(r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"),
(r"wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_wget"),
(r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)", "read_secrets"),
(r"authorized_keys", "ssh_backdoor"),
(r"/etc/sudoers|visudo", "sudoers_mod"),
(r"rm\s+-rf\s+/", "destructive_root_rm"),
]
_CRON_INVISIBLE_CHARS = {
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
"\u200b",
"\u200c",
"\u200d",
"\u2060",
"\ufeff",
"\u202a",
"\u202b",
"\u202c",
"\u202d",
"\u202e",
}
@ -60,17 +67,18 @@ def _scan_cron_prompt(prompt: str) -> str:
# Tool: schedule_cronjob
# =============================================================================
def schedule_cronjob(
prompt: str,
schedule: str,
name: Optional[str] = None,
repeat: Optional[int] = None,
deliver: Optional[str] = None,
task_id: str = None
name: str | None = None,
repeat: int | None = None,
deliver: str | None = None,
task_id: str = None,
) -> str:
"""
Schedule an automated task to run the agent on a schedule.
IMPORTANT: When the cronjob runs, it starts a COMPLETELY FRESH session.
The agent will have NO memory of this conversation or any prior context.
Therefore, the prompt MUST contain ALL necessary information:
@ -78,12 +86,12 @@ def schedule_cronjob(
- Specific file paths, URLs, or identifiers
- Clear success criteria
- Any relevant background information
BAD prompt: "Check on that server issue"
GOOD prompt: "SSH into server 192.168.1.100 as user 'deploy', check if nginx
is running with 'systemctl status nginx', and verify the site
GOOD prompt: "SSH into server 192.168.1.100 as user 'deploy', check if nginx
is running with 'systemctl status nginx', and verify the site
https://example.com returns HTTP 200. Report any issues found."
Args:
prompt: Complete, self-contained instructions for the future agent.
Must include ALL context needed - the agent won't remember anything.
@ -105,7 +113,7 @@ def schedule_cronjob(
- "signal": Send to Signal home channel
- "telegram:123456": Send to specific chat ID
- "signal:+15551234567": Send to specific Signal number
Returns:
JSON with job_id, next_run time, and confirmation
"""
@ -124,17 +132,10 @@ def schedule_cronjob(
"chat_id": origin_chat_id,
"chat_name": os.getenv("HERMES_SESSION_CHAT_NAME"),
}
try:
job = create_job(
prompt=prompt,
schedule=schedule,
name=name,
repeat=repeat,
deliver=deliver,
origin=origin
)
job = create_job(prompt=prompt, schedule=schedule, name=name, repeat=repeat, deliver=deliver, origin=origin)
# Format repeat info for display
times = job["repeat"].get("times")
if times is None:
@ -143,23 +144,23 @@ def schedule_cronjob(
repeat_display = "once"
else:
repeat_display = f"{times} times"
return json.dumps({
"success": True,
"job_id": job["id"],
"name": job["name"],
"schedule": job["schedule_display"],
"repeat": repeat_display,
"deliver": job.get("deliver", "local"),
"next_run_at": job["next_run_at"],
"message": f"Cronjob '{job['name']}' created. It will run {repeat_display}, deliver to {job.get('deliver', 'local')}, next at {job['next_run_at']}."
}, indent=2)
return json.dumps(
{
"success": True,
"job_id": job["id"],
"name": job["name"],
"schedule": job["schedule_display"],
"repeat": repeat_display,
"deliver": job.get("deliver", "local"),
"next_run_at": job["next_run_at"],
"message": f"Cronjob '{job['name']}' created. It will run {repeat_display}, deliver to {job.get('deliver', 'local')}, next at {job['next_run_at']}.",
},
indent=2,
)
except Exception as e:
return json.dumps({
"success": False,
"error": str(e)
}, indent=2)
return json.dumps({"success": False, "error": str(e)}, indent=2)
SCHEDULE_CRONJOB_SCHEMA = {
@ -177,7 +178,7 @@ The future agent will NOT remember anything from the current conversation.
SCHEDULE FORMATS:
- One-shot: "30m", "2h", "1d" (runs once after delay)
- Interval: "every 30m", "every 2h" (recurring)
- Interval: "every 30m", "every 2h" (recurring)
- Cron: "0 9 * * *" (cron expression for precise scheduling)
- Timestamp: "2026-02-03T14:00:00" (specific date/time)
@ -202,27 +203,24 @@ Use for: reminders, periodic checks, scheduled reports, automated maintenance.""
"properties": {
"prompt": {
"type": "string",
"description": "Complete, self-contained instructions. Must include ALL context - the future agent will have NO memory of this conversation."
"description": "Complete, self-contained instructions. Must include ALL context - the future agent will have NO memory of this conversation.",
},
"schedule": {
"type": "string",
"description": "When to run: '30m' (once in 30min), 'every 30m' (recurring), '0 9 * * *' (cron), or ISO timestamp"
},
"name": {
"type": "string",
"description": "Optional human-friendly name for the job"
"description": "When to run: '30m' (once in 30min), 'every 30m' (recurring), '0 9 * * *' (cron), or ISO timestamp",
},
"name": {"type": "string", "description": "Optional human-friendly name for the job"},
"repeat": {
"type": "integer",
"description": "How many times to run. Omit for default (once for one-shot, forever for recurring). Set to N for exactly N runs."
"description": "How many times to run. Omit for default (once for one-shot, forever for recurring). Set to N for exactly N runs.",
},
"deliver": {
"type": "string",
"description": "Where to send output: 'origin' (back to this chat), 'local' (files only), 'telegram', 'discord', 'signal', or 'platform:chat_id'"
}
"description": "Where to send output: 'origin' (back to this chat), 'local' (files only), 'telegram', 'discord', 'signal', or 'platform:chat_id'",
},
},
"required": ["prompt", "schedule"]
}
"required": ["prompt", "schedule"],
},
}
@ -230,10 +228,11 @@ Use for: reminders, periodic checks, scheduled reports, automated maintenance.""
# Tool: list_cronjobs
# =============================================================================
def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str:
"""
List all scheduled cronjobs.
Returns information about each job including:
- Job ID (needed for removal)
- Name
@ -241,16 +240,16 @@ def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str:
- Repeat status (completed/total or 'forever')
- Next scheduled run time
- Last run time and status (if any)
Args:
include_disabled: Whether to include disabled/completed jobs
Returns:
JSON array of all scheduled jobs
"""
try:
jobs = list_jobs(include_disabled=include_disabled)
formatted_jobs = []
for job in jobs:
# Format repeat status
@ -260,31 +259,26 @@ def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str:
repeat_status = "forever"
else:
repeat_status = f"{completed}/{times}"
formatted_jobs.append({
"job_id": job["id"],
"name": job["name"],
"prompt_preview": job["prompt"][:100] + "..." if len(job["prompt"]) > 100 else job["prompt"],
"schedule": job["schedule_display"],
"repeat": repeat_status,
"deliver": job.get("deliver", "local"),
"next_run_at": job.get("next_run_at"),
"last_run_at": job.get("last_run_at"),
"last_status": job.get("last_status"),
"enabled": job.get("enabled", True)
})
return json.dumps({
"success": True,
"count": len(formatted_jobs),
"jobs": formatted_jobs
}, indent=2)
formatted_jobs.append(
{
"job_id": job["id"],
"name": job["name"],
"prompt_preview": job["prompt"][:100] + "..." if len(job["prompt"]) > 100 else job["prompt"],
"schedule": job["schedule_display"],
"repeat": repeat_status,
"deliver": job.get("deliver", "local"),
"next_run_at": job.get("next_run_at"),
"last_run_at": job.get("last_run_at"),
"last_status": job.get("last_status"),
"enabled": job.get("enabled", True),
}
)
return json.dumps({"success": True, "count": len(formatted_jobs), "jobs": formatted_jobs}, indent=2)
except Exception as e:
return json.dumps({
"success": False,
"error": str(e)
}, indent=2)
return json.dumps({"success": False, "error": str(e)}, indent=2)
LIST_CRONJOBS_SCHEMA = {
@ -302,11 +296,11 @@ Returns job_id, name, schedule, repeat status, next/last run times.""",
"properties": {
"include_disabled": {
"type": "boolean",
"description": "Include disabled/completed jobs in the list (default: false)"
"description": "Include disabled/completed jobs in the list (default: false)",
}
},
"required": []
}
"required": [],
},
}
@ -314,48 +308,45 @@ Returns job_id, name, schedule, repeat status, next/last run times.""",
# Tool: remove_cronjob
# =============================================================================
def remove_cronjob(job_id: str, task_id: str = None) -> str:
"""
Remove a scheduled cronjob by its ID.
Use list_cronjobs first to find the job_id of the job you want to remove.
Args:
job_id: The ID of the job to remove (from list_cronjobs output)
Returns:
JSON confirmation of removal
"""
try:
job = get_job(job_id)
if not job:
return json.dumps({
"success": False,
"error": f"Job with ID '{job_id}' not found. Use list_cronjobs to see available jobs."
}, indent=2)
return json.dumps(
{
"success": False,
"error": f"Job with ID '{job_id}' not found. Use list_cronjobs to see available jobs.",
},
indent=2,
)
removed = remove_job(job_id)
if removed:
return json.dumps({
"success": True,
"message": f"Cronjob '{job['name']}' (ID: {job_id}) has been removed.",
"removed_job": {
"id": job_id,
"name": job["name"],
"schedule": job["schedule_display"]
}
}, indent=2)
return json.dumps(
{
"success": True,
"message": f"Cronjob '{job['name']}' (ID: {job_id}) has been removed.",
"removed_job": {"id": job_id, "name": job["name"], "schedule": job["schedule_display"]},
},
indent=2,
)
else:
return json.dumps({
"success": False,
"error": f"Failed to remove job '{job_id}'"
}, indent=2)
return json.dumps({"success": False, "error": f"Failed to remove job '{job_id}'"}, indent=2)
except Exception as e:
return json.dumps({
"success": False,
"error": str(e)
}, indent=2)
return json.dumps({"success": False, "error": str(e)}, indent=2)
REMOVE_CRONJOB_SCHEMA = {
@ -368,13 +359,10 @@ use this to cancel a job before it completes.""",
"parameters": {
"type": "object",
"properties": {
"job_id": {
"type": "string",
"description": "The ID of the cronjob to remove (from list_cronjobs output)"
}
"job_id": {"type": "string", "description": "The ID of the cronjob to remove (from list_cronjobs output)"}
},
"required": ["job_id"]
}
"required": ["job_id"],
},
}
@ -382,44 +370,34 @@ use this to cancel a job before it completes.""",
# Requirements check
# =============================================================================
def check_cronjob_requirements() -> bool:
"""
Check if cronjob tools can be used.
Available in interactive CLI mode and gateway/messaging platforms.
Cronjobs are server-side scheduled tasks so they work from any interface.
"""
return bool(
os.getenv("HERMES_INTERACTIVE")
or os.getenv("HERMES_GATEWAY_SESSION")
or os.getenv("HERMES_EXEC_ASK")
)
return bool(os.getenv("HERMES_INTERACTIVE") or os.getenv("HERMES_GATEWAY_SESSION") or os.getenv("HERMES_EXEC_ASK"))
# =============================================================================
# Exports
# =============================================================================
def get_cronjob_tool_definitions():
"""Return tool definitions for cronjob management."""
return [
SCHEDULE_CRONJOB_SCHEMA,
LIST_CRONJOBS_SCHEMA,
REMOVE_CRONJOB_SCHEMA
]
return [SCHEDULE_CRONJOB_SCHEMA, LIST_CRONJOBS_SCHEMA, REMOVE_CRONJOB_SCHEMA]
# For direct testing
if __name__ == "__main__":
# Test the tools
print("Testing schedule_cronjob:")
result = schedule_cronjob(
prompt="Test prompt for cron job",
schedule="5m",
name="Test Job"
)
result = schedule_cronjob(prompt="Test prompt for cron job", schedule="5m", name="Test Job")
print(result)
print("\nTesting list_cronjobs:")
result = list_cronjobs()
print(result)
@ -438,7 +416,8 @@ registry.register(
name=args.get("name"),
repeat=args.get("repeat"),
deliver=args.get("deliver"),
task_id=kw.get("task_id")),
task_id=kw.get("task_id"),
),
check_fn=check_cronjob_requirements,
)
registry.register(
@ -446,16 +425,14 @@ registry.register(
toolset="cronjob",
schema=LIST_CRONJOBS_SCHEMA,
handler=lambda args, **kw: list_cronjobs(
include_disabled=args.get("include_disabled", False),
task_id=kw.get("task_id")),
include_disabled=args.get("include_disabled", False), task_id=kw.get("task_id")
),
check_fn=check_cronjob_requirements,
)
registry.register(
name="remove_cronjob",
toolset="cronjob",
schema=REMOVE_CRONJOB_SCHEMA,
handler=lambda args, **kw: remove_cronjob(
job_id=args.get("job_id", ""),
task_id=kw.get("task_id")),
handler=lambda args, **kw: remove_cronjob(job_id=args.get("job_id", ""), task_id=kw.get("task_id")),
check_fn=check_cronjob_requirements,
)

View file

@ -27,7 +27,7 @@ import logging
import os
import uuid
from pathlib import Path
from typing import Any, Dict
from typing import Any
logger = logging.getLogger(__name__)
@ -44,27 +44,28 @@ class DebugSession:
self.enabled = os.getenv(env_var, "false").lower() == "true"
self.session_id = str(uuid.uuid4()) if self.enabled else ""
self.log_dir = Path("./logs")
self._calls: list[Dict[str, Any]] = []
self._calls: list[dict[str, Any]] = []
self._start_time = datetime.datetime.now().isoformat() if self.enabled else ""
if self.enabled:
self.log_dir.mkdir(exist_ok=True)
logger.debug("%s debug mode enabled - Session ID: %s",
tool_name, self.session_id)
logger.debug("%s debug mode enabled - Session ID: %s", tool_name, self.session_id)
@property
def active(self) -> bool:
return self.enabled
def log_call(self, call_name: str, call_data: Dict[str, Any]) -> None:
def log_call(self, call_name: str, call_data: dict[str, Any]) -> None:
"""Append a tool-call entry to the in-memory log."""
if not self.enabled:
return
self._calls.append({
"timestamp": datetime.datetime.now().isoformat(),
"tool_name": call_name,
**call_data,
})
self._calls.append(
{
"timestamp": datetime.datetime.now().isoformat(),
"tool_name": call_name,
**call_data,
}
)
def save(self) -> None:
"""Flush the in-memory log to a JSON file in the logs directory."""
@ -87,7 +88,7 @@ class DebugSession:
except Exception as e:
logger.error("Error saving %s debug log: %s", self.tool_name, e)
def get_session_info(self) -> Dict[str, Any]:
def get_session_info(self) -> dict[str, Any]:
"""Return a summary dict suitable for returning from get_debug_session_info()."""
if not self.enabled:
return {

View file

@ -20,21 +20,22 @@ import contextlib
import io
import json
import logging
import os
import sys
import time
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional
from typing import Any
# Tools that children must never have access to
DELEGATE_BLOCKED_TOOLS = frozenset([
"delegate_task", # no recursive delegation
"clarify", # no user interaction
"memory", # no writes to shared MEMORY.md
"send_message", # no cross-platform side effects
"execute_code", # children should reason step-by-step, not write scripts
])
DELEGATE_BLOCKED_TOOLS = frozenset(
[
"delegate_task", # no recursive delegation
"clarify", # no user interaction
"memory", # no writes to shared MEMORY.md
"send_message", # no cross-platform side effects
"execute_code", # children should reason step-by-step, not write scripts
]
)
MAX_CONCURRENT_CHILDREN = 3
MAX_DEPTH = 2 # parent (0) -> child (1) -> grandchild rejected (2)
@ -47,7 +48,7 @@ def check_delegate_requirements() -> bool:
return True
def _build_child_system_prompt(goal: str, context: Optional[str] = None) -> str:
def _build_child_system_prompt(goal: str, context: str | None = None) -> str:
"""Build a focused system prompt for a child agent."""
parts = [
"You are a focused subagent working on a specific delegated task.",
@ -69,15 +70,18 @@ def _build_child_system_prompt(goal: str, context: Optional[str] = None) -> str:
return "\n".join(parts)
def _strip_blocked_tools(toolsets: List[str]) -> List[str]:
def _strip_blocked_tools(toolsets: list[str]) -> list[str]:
"""Remove toolsets that contain only blocked tools."""
blocked_toolset_names = {
"delegation", "clarify", "memory", "code_execution",
"delegation",
"clarify",
"memory",
"code_execution",
}
return [t for t in toolsets if t not in blocked_toolset_names]
def _build_child_progress_callback(task_index: int, parent_agent, task_count: int = 1) -> Optional[callable]:
def _build_child_progress_callback(task_index: int, parent_agent, task_count: int = 1) -> Callable | None:
"""Build a callback that relays child agent tool calls to the parent display.
Two display paths:
@ -87,8 +91,8 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
Returns None if no display mechanism is available, in which case the
child agent runs with no progress callback (identical to current behavior).
"""
spinner = getattr(parent_agent, '_delegate_spinner', None)
parent_cb = getattr(parent_agent, 'tool_progress_callback', None)
spinner = getattr(parent_agent, "_delegate_spinner", None)
parent_cb = getattr(parent_agent, "tool_progress_callback", None)
if not spinner and not parent_cb:
return None # No display → no callback → zero behavior change
@ -98,7 +102,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
# Gateway: batch tool names, flush periodically
_BATCH_SIZE = 5
_batch: List[str] = []
_batch: list[str] = []
def _callback(tool_name: str, preview: str = None):
# Special "_thinking" event: model produced text content (reasoning)
@ -106,7 +110,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
if spinner:
short = (preview[:55] + "...") if preview and len(preview) > 55 else (preview or "")
try:
spinner.print_above(f" {prefix}├─ 💭 \"{short}\"")
spinner.print_above(f' {prefix}├─ 💭 "{short}"')
except Exception:
pass
# Don't relay thinking to gateway (too noisy for chat)
@ -116,17 +120,25 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
if spinner:
short = (preview[:35] + "...") if preview and len(preview) > 35 else (preview or "")
tool_emojis = {
"terminal": "💻", "web_search": "🔍", "web_extract": "📄",
"read_file": "📖", "write_file": "✍️", "patch": "🔧",
"search_files": "🔎", "list_directory": "📂",
"browser_navigate": "🌐", "browser_click": "👆",
"text_to_speech": "🔊", "image_generate": "🎨",
"vision_analyze": "👁️", "process": "⚙️",
"terminal": "💻",
"web_search": "🔍",
"web_extract": "📄",
"read_file": "📖",
"write_file": "✍️",
"patch": "🔧",
"search_files": "🔎",
"list_directory": "📂",
"browser_navigate": "🌐",
"browser_click": "👆",
"text_to_speech": "🔊",
"image_generate": "🎨",
"vision_analyze": "👁️",
"process": "⚙️",
}
emoji = tool_emojis.get(tool_name, "")
line = f" {prefix}├─ {emoji} {tool_name}"
if short:
line += f" \"{short}\""
line += f' "{short}"'
try:
spinner.print_above(line)
except Exception:
@ -159,13 +171,13 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
def _run_single_child(
task_index: int,
goal: str,
context: Optional[str],
toolsets: Optional[List[str]],
model: Optional[str],
context: str | None,
toolsets: list[str] | None,
model: str | None,
max_iterations: int,
parent_agent,
task_count: int = 1,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Spawn and run a single child agent. Called from within a thread.
Returns a structured result dict.
@ -216,7 +228,7 @@ def _run_single_child(
skip_context_files=True,
skip_memory=True,
clarify_callback=None,
session_db=getattr(parent_agent, '_session_db', None),
session_db=getattr(parent_agent, "_session_db", None),
providers_allowed=parent_agent.providers_allowed,
providers_ignored=parent_agent.providers_ignored,
providers_order=parent_agent.providers_order,
@ -226,10 +238,10 @@ def _run_single_child(
)
# Set delegation depth so children can't spawn grandchildren
child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1
child._delegate_depth = getattr(parent_agent, "_delegate_depth", 0) + 1
# Register child for interrupt propagation
if hasattr(parent_agent, '_active_children'):
if hasattr(parent_agent, "_active_children"):
parent_agent._active_children.append(child)
# Run with stdout/stderr suppressed to prevent interleaved output
@ -238,7 +250,7 @@ def _run_single_child(
result = child.run_conversation(user_message=goal)
# Flush any remaining batched progress to gateway
if child_progress_cb and hasattr(child_progress_cb, '_flush'):
if child_progress_cb and hasattr(child_progress_cb, "_flush"):
try:
child_progress_cb._flush()
except Exception:
@ -258,7 +270,7 @@ def _run_single_child(
else:
status = "failed"
entry: Dict[str, Any] = {
entry: dict[str, Any] = {
"task_index": task_index,
"status": status,
"summary": summary,
@ -284,7 +296,7 @@ def _run_single_child(
finally:
# Unregister child from interrupt propagation
if hasattr(parent_agent, '_active_children'):
if hasattr(parent_agent, "_active_children"):
try:
parent_agent._active_children.remove(child)
except (ValueError, UnboundLocalError):
@ -292,11 +304,11 @@ def _run_single_child(
def delegate_task(
goal: Optional[str] = None,
context: Optional[str] = None,
toolsets: Optional[List[str]] = None,
tasks: Optional[List[Dict[str, Any]]] = None,
max_iterations: Optional[int] = None,
goal: str | None = None,
context: str | None = None,
toolsets: list[str] | None = None,
tasks: list[dict[str, Any]] | None = None,
max_iterations: int | None = None,
parent_agent=None,
) -> str:
"""
@ -312,14 +324,11 @@ def delegate_task(
return json.dumps({"error": "delegate_task requires a parent agent context."})
# Depth limit
depth = getattr(parent_agent, '_delegate_depth', 0)
depth = getattr(parent_agent, "_delegate_depth", 0)
if depth >= MAX_DEPTH:
return json.dumps({
"error": (
f"Delegation depth limit reached ({MAX_DEPTH}). "
"Subagents cannot spawn further subagents."
)
})
return json.dumps(
{"error": (f"Delegation depth limit reached ({MAX_DEPTH}). Subagents cannot spawn further subagents.")}
)
# Load config
cfg = _load_config()
@ -366,7 +375,7 @@ def delegate_task(
else:
# Batch -- run in parallel with per-task progress lines
completed_count = 0
spinner_ref = getattr(parent_agent, '_delegate_spinner', None)
spinner_ref = getattr(parent_agent, "_delegate_spinner", None)
# Save stdout/stderr before the executor — redirect_stdout in child
# threads races on sys.stdout and can leave it as devnull permanently.
@ -412,7 +421,7 @@ def delegate_task(
status = entry.get("status", "?")
icon = "" if status == "completed" else ""
remaining = n_tasks - completed_count
completion_line = f"{icon} [{idx+1}/{n_tasks}] {label} ({dur}s)"
completion_line = f"{icon} [{idx + 1}/{n_tasks}] {label} ({dur}s)"
if spinner_ref:
try:
spinner_ref.print_above(completion_line)
@ -437,16 +446,20 @@ def delegate_task(
total_duration = round(time.monotonic() - overall_start, 2)
return json.dumps({
"results": results,
"total_duration_seconds": total_duration,
}, ensure_ascii=False)
return json.dumps(
{
"results": results,
"total_duration_seconds": total_duration,
},
ensure_ascii=False,
)
def _load_config() -> dict:
"""Load delegation config from CLI_CONFIG if available."""
try:
from cli import CLI_CONFIG
return CLI_CONFIG.get("delegation", {})
except Exception:
return {}
@ -537,10 +550,7 @@ DELEGATE_TASK_SCHEMA = {
},
"max_iterations": {
"type": "integer",
"description": (
"Max tool-calling turns per subagent (default: 50). "
"Only set lower for simple tasks."
),
"description": ("Max tool-calling turns per subagent (default: 50). Only set lower for simple tasks."),
},
},
"required": [],
@ -561,6 +571,7 @@ registry.register(
toolsets=args.get("toolsets"),
tasks=args.get("tasks"),
max_iterations=args.get("max_iterations"),
parent_agent=kw.get("parent_agent")),
parent_agent=kw.get("parent_agent"),
),
check_fn=check_delegate_requirements,
)

View file

@ -1,8 +1,8 @@
"""Base class for all Hermes execution environment backends."""
from abc import ABC, abstractmethod
import os
import subprocess
from abc import ABC, abstractmethod
from pathlib import Path
@ -34,9 +34,9 @@ class BaseEnvironment(ABC):
self.env = env or {}
@abstractmethod
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
def execute(
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
) -> dict:
"""Execute a command, return {"output": str, "returncode": int}."""
...
@ -62,10 +62,10 @@ class BaseEnvironment(ABC):
def _prepare_command(self, command: str) -> str:
"""Transform sudo commands if SUDO_PASSWORD is available."""
from tools.terminal_tool import _transform_sudo_command
return _transform_sudo_command(command)
def _build_run_kwargs(self, timeout: int | None,
stdin_data: str | None = None) -> dict:
def _build_run_kwargs(self, timeout: int | None, stdin_data: str | None = None) -> dict:
"""Build common subprocess.run kwargs for non-interactive execution."""
kw = {
"text": True,

View file

@ -11,7 +11,6 @@ import shlex
import threading
import uuid
import warnings
from typing import Optional
from tools.environments.base import BaseEnvironment
from tools.interrupt import is_interrupted
@ -32,8 +31,8 @@ class DaytonaEnvironment(BaseEnvironment):
cwd: str = "/home/daytona",
timeout: int = 60,
cpu: int = 1,
memory: int = 5120, # MB (hermes convention)
disk: int = 10240, # MB (Daytona platform max is 10GB)
memory: int = 5120, # MB (hermes convention)
disk: int = 10240, # MB (Daytona platform max is 10GB)
persistent_filesystem: bool = True,
task_id: str = "default",
):
@ -41,8 +40,8 @@ class DaytonaEnvironment(BaseEnvironment):
super().__init__(cwd=cwd, timeout=timeout)
from daytona import (
Daytona,
CreateSandboxFromImageParams,
Daytona,
DaytonaError,
Resources,
SandboxState,
@ -73,13 +72,11 @@ class DaytonaEnvironment(BaseEnvironment):
try:
self._sandbox = self._daytona.find_one(labels=labels)
self._sandbox.start()
logger.info("Daytona: resumed sandbox %s for task %s",
self._sandbox.id, task_id)
logger.info("Daytona: resumed sandbox %s for task %s", self._sandbox.id, task_id)
except DaytonaError:
self._sandbox = None
except Exception as e:
logger.warning("Daytona: failed to resume sandbox for task %s: %s",
task_id, e)
logger.warning("Daytona: failed to resume sandbox for task %s: %s", task_id, e)
self._sandbox = None
# Create a fresh sandbox if we don't have one
@ -92,8 +89,7 @@ class DaytonaEnvironment(BaseEnvironment):
resources=resources,
)
)
logger.info("Daytona: created sandbox %s for task %s",
self._sandbox.id, task_id)
logger.info("Daytona: created sandbox %s for task %s", self._sandbox.id, task_id)
# Resolve cwd: detect actual home dir inside the sandbox
if self._requested_cwd in ("~", "/home/daytona"):
@ -112,7 +108,7 @@ class DaytonaEnvironment(BaseEnvironment):
self._sandbox.start()
logger.info("Daytona: restarted sandbox %s", self._sandbox.id)
def _exec_in_thread(self, exec_command: str, cwd: Optional[str], timeout: int) -> dict:
def _exec_in_thread(self, exec_command: str, cwd: str | None, timeout: int) -> dict:
"""Run exec in a background thread with interrupt polling.
The Daytona SDK's exec(timeout=...) parameter is unreliable (the
@ -130,7 +126,8 @@ class DaytonaEnvironment(BaseEnvironment):
def _run():
try:
response = self._sandbox.process.exec(
timed_command, cwd=cwd,
timed_command,
cwd=cwd,
)
result_holder["value"] = {
"output": response.result or "",
@ -169,9 +166,9 @@ class DaytonaEnvironment(BaseEnvironment):
return {"error": result_holder["error"]}
return result_holder["value"]
def execute(self, command: str, cwd: str = "", *,
timeout: Optional[int] = None,
stdin_data: Optional[str] = None) -> dict:
def execute(
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
) -> dict:
with self._lock:
self._ensure_sandbox_ready()
@ -189,6 +186,7 @@ class DaytonaEnvironment(BaseEnvironment):
if "error" in result:
from daytona import DaytonaError
err = result["error"]
if isinstance(err, DaytonaError):
with self._lock:
@ -210,8 +208,7 @@ class DaytonaEnvironment(BaseEnvironment):
try:
if self._persistent:
self._sandbox.stop()
logger.info("Daytona: stopped sandbox %s (filesystem preserved)",
self._sandbox.id)
logger.info("Daytona: stopped sandbox %s (filesystem preserved)", self._sandbox.id)
else:
self._daytona.delete(self._sandbox)
logger.info("Daytona: deleted sandbox %s", self._sandbox.id)

View file

@ -11,7 +11,6 @@ import subprocess
import sys
import threading
import time
from typing import Optional
from tools.environments.base import BaseEnvironment
from tools.interrupt import is_interrupted
@ -19,7 +18,6 @@ from tools.interrupt import is_interrupted
logger = logging.getLogger(__name__)
# Security flags applied to every container.
# The container itself is the security boundary (isolated from host).
# We drop all capabilities then add back the minimum needed:
@ -28,19 +26,28 @@ logger = logging.getLogger(__name__)
# Block privilege escalation and limit PIDs.
# /tmp is size-limited and nosuid but allows exec (needed by pip/npm builds).
_SECURITY_ARGS = [
"--cap-drop", "ALL",
"--cap-add", "DAC_OVERRIDE",
"--cap-add", "CHOWN",
"--cap-add", "FOWNER",
"--security-opt", "no-new-privileges",
"--pids-limit", "256",
"--tmpfs", "/tmp:rw,nosuid,size=512m",
"--tmpfs", "/var/tmp:rw,noexec,nosuid,size=256m",
"--tmpfs", "/run:rw,noexec,nosuid,size=64m",
"--cap-drop",
"ALL",
"--cap-add",
"DAC_OVERRIDE",
"--cap-add",
"CHOWN",
"--cap-add",
"FOWNER",
"--security-opt",
"no-new-privileges",
"--pids-limit",
"256",
"--tmpfs",
"/tmp:rw,nosuid,size=512m",
"--tmpfs",
"/var/tmp:rw,noexec,nosuid,size=256m",
"--tmpfs",
"/run:rw,noexec,nosuid,size=64m",
]
_storage_opt_ok: Optional[bool] = None # cached result across instances
_storage_opt_ok: bool | None = None # cached result across instances
class DockerEnvironment(BaseEnvironment):
@ -74,7 +81,7 @@ class DockerEnvironment(BaseEnvironment):
self._base_image = image
self._persistent = persistent_filesystem
self._task_id = task_id
self._container_id: Optional[str] = None
self._container_id: str | None = None
logger.info(f"DockerEnvironment volumes: {volumes}")
# Ensure volumes is a list (config.yaml could be malformed)
if volumes is not None and not isinstance(volumes, list):
@ -105,8 +112,8 @@ class DockerEnvironment(BaseEnvironment):
# mode uses tmpfs (ephemeral, fast, gone on cleanup).
from tools.environments.base import get_sandbox_dir
self._workspace_dir: Optional[str] = None
self._home_dir: Optional[str] = None
self._workspace_dir: str | None = None
self._home_dir: str | None = None
if self._persistent:
sandbox = get_sandbox_dir() / "docker" / task_id
self._workspace_dir = str(sandbox / "workspace")
@ -114,14 +121,19 @@ class DockerEnvironment(BaseEnvironment):
os.makedirs(self._workspace_dir, exist_ok=True)
os.makedirs(self._home_dir, exist_ok=True)
writable_args = [
"-v", f"{self._workspace_dir}:/workspace",
"-v", f"{self._home_dir}:/root",
"-v",
f"{self._workspace_dir}:/workspace",
"-v",
f"{self._home_dir}:/root",
]
else:
writable_args = [
"--tmpfs", "/workspace:rw,exec,size=10g",
"--tmpfs", "/home:rw,exec,size=1g",
"--tmpfs", "/root:rw,exec,size=1g",
"--tmpfs",
"/workspace:rw,exec,size=10g",
"--tmpfs",
"/home:rw,exec,size=1g",
"--tmpfs",
"/root:rw,exec,size=1g",
]
# All containers get security hardening (capabilities dropped, no privilege
@ -129,7 +141,7 @@ class DockerEnvironment(BaseEnvironment):
# can install packages as needed.
# User-configured volume mounts (from config.yaml docker_volumes)
volume_args = []
for vol in (volumes or []):
for vol in volumes or []:
if not isinstance(vol, str):
logger.warning(f"Docker volume entry is not a string: {vol!r}")
continue
@ -146,7 +158,9 @@ class DockerEnvironment(BaseEnvironment):
logger.info(f"Docker run_args: {all_run_args}")
self._inner = _Docker(
image=image, cwd=cwd, timeout=timeout,
image=image,
cwd=cwd,
timeout=timeout,
run_args=all_run_args,
)
self._container_id = self._inner.container_id
@ -154,7 +168,7 @@ class DockerEnvironment(BaseEnvironment):
@staticmethod
def _storage_opt_supported() -> bool:
"""Check if Docker's storage driver supports --storage-opt size=.
Only overlay2 on XFS with pquota supports per-container disk quotas.
Ubuntu (and most distros) default to ext4, where this flag errors out.
"""
@ -164,7 +178,9 @@ class DockerEnvironment(BaseEnvironment):
try:
result = subprocess.run(
["docker", "info", "--format", "{{.Driver}}"],
capture_output=True, text=True, timeout=10,
capture_output=True,
text=True,
timeout=10,
)
driver = result.stdout.strip().lower()
if driver != "overlay2":
@ -174,14 +190,15 @@ class DockerEnvironment(BaseEnvironment):
# Probe by attempting a dry-ish run — the fastest reliable check.
probe = subprocess.run(
["docker", "create", "--storage-opt", "size=1m", "hello-world"],
capture_output=True, text=True, timeout=15,
capture_output=True,
text=True,
timeout=15,
)
if probe.returncode == 0:
# Clean up the created container
container_id = probe.stdout.strip()
if container_id:
subprocess.run(["docker", "rm", container_id],
capture_output=True, timeout=5)
subprocess.run(["docker", "rm", container_id], capture_output=True, timeout=5)
_storage_opt_ok = True
else:
_storage_opt_ok = False
@ -190,9 +207,9 @@ class DockerEnvironment(BaseEnvironment):
logger.debug("Docker --storage-opt support: %s", _storage_opt_ok)
return _storage_opt_ok
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
def execute(
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
) -> dict:
exec_command = self._prepare_command(command)
work_dir = cwd or self.cwd
effective_timeout = timeout or self.timeout
@ -218,7 +235,8 @@ class DockerEnvironment(BaseEnvironment):
_output_chunks = []
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stdin=subprocess.PIPE if stdin_data else subprocess.DEVNULL,
text=True,
)
@ -269,6 +287,7 @@ class DockerEnvironment(BaseEnvironment):
if not self._persistent:
import shutil
for d in (self._workspace_dir, self._home_dir):
if d:
shutil.rmtree(d, ignore_errors=True)

View file

@ -154,9 +154,9 @@ class LocalEnvironment(BaseEnvironment):
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None):
super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env)
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
def execute(
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
) -> dict:
from tools.terminal_tool import _interrupt_event
work_dir = cwd or self.cwd or os.getcwd()
@ -172,11 +172,7 @@ class LocalEnvironment(BaseEnvironment):
# Wrap with output fences so we can later extract the real
# command output and discard shell init/exit noise.
fenced_cmd = (
f"printf '{_OUTPUT_FENCE}';"
f" {exec_command};"
f" __hermes_rc=$?;"
f" printf '{_OUTPUT_FENCE}';"
f" exit $__hermes_rc"
f"printf '{_OUTPUT_FENCE}'; {exec_command}; __hermes_rc=$?; printf '{_OUTPUT_FENCE}'; exit $__hermes_rc"
)
# Ensure PATH always includes standard dirs — systemd services
# and some terminal multiplexers inherit a minimal PATH.
@ -200,12 +196,14 @@ class LocalEnvironment(BaseEnvironment):
)
if stdin_data is not None:
def _write_stdin():
try:
proc.stdin.write(stdin_data)
proc.stdin.close()
except (BrokenPipeError, OSError):
pass
threading.Thread(target=_write_stdin, daemon=True).start()
_output_chunks: list[str] = []

View file

@ -8,10 +8,9 @@ project files, and config changes survive across sessions.
import json
import logging
import threading
import time
import uuid
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any
from tools.environments.base import BaseEnvironment
from tools.interrupt import is_interrupted
@ -21,7 +20,7 @@ logger = logging.getLogger(__name__)
_SNAPSHOT_STORE = Path.home() / ".hermes" / "modal_snapshots.json"
def _load_snapshots() -> Dict[str, str]:
def _load_snapshots() -> dict[str, str]:
"""Load snapshot ID mapping from disk."""
if _SNAPSHOT_STORE.exists():
try:
@ -31,7 +30,7 @@ def _load_snapshots() -> Dict[str, str]:
return {}
def _save_snapshots(data: Dict[str, str]) -> None:
def _save_snapshots(data: dict[str, str]) -> None:
"""Persist snapshot ID mapping to disk."""
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
@ -52,7 +51,7 @@ class ModalEnvironment(BaseEnvironment):
image: str,
cwd: str = "~",
timeout: int = 60,
modal_sandbox_kwargs: Optional[Dict[str, Any]] = None,
modal_sandbox_kwargs: dict[str, Any] | None = None,
persistent_filesystem: bool = True,
task_id: str = "default",
):
@ -61,6 +60,7 @@ class ModalEnvironment(BaseEnvironment):
if not ModalEnvironment._patches_applied:
try:
from environments.patches import apply_patches
apply_patches()
except ImportError:
pass
@ -79,6 +79,7 @@ class ModalEnvironment(BaseEnvironment):
if snapshot_id:
try:
import modal
restored_image = modal.Image.from_id(snapshot_id)
logger.info("Modal: restoring from snapshot %s", snapshot_id[:20])
except Exception as e:
@ -88,6 +89,7 @@ class ModalEnvironment(BaseEnvironment):
effective_image = restored_image if restored_image else image
from minisweagent.environments.extra.swerex_modal import SwerexModalEnvironment
self._inner = SwerexModalEnvironment(
image=effective_image,
cwd=cwd,
@ -97,9 +99,9 @@ class ModalEnvironment(BaseEnvironment):
modal_sandbox_kwargs=sandbox_kwargs,
)
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
def execute(
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
) -> dict:
if stdin_data is not None:
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
while marker in stdin_data:
@ -139,29 +141,29 @@ class ModalEnvironment(BaseEnvironment):
"""Snapshot the filesystem (if persistent) then stop the sandbox."""
if self._persistent:
try:
sandbox = getattr(self._inner, 'deployment', None)
sandbox = getattr(sandbox, '_sandbox', None) if sandbox else None
sandbox = getattr(self._inner, "deployment", None)
sandbox = getattr(sandbox, "_sandbox", None) if sandbox else None
if sandbox:
import asyncio
async def _snapshot():
img = await sandbox.snapshot_filesystem.aio()
return img.object_id
try:
snapshot_id = asyncio.run(_snapshot())
except RuntimeError:
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
snapshot_id = pool.submit(
asyncio.run, _snapshot()
).result(timeout=60)
snapshot_id = pool.submit(asyncio.run, _snapshot()).result(timeout=60)
snapshots = _load_snapshots()
snapshots[self._task_id] = snapshot_id
_save_snapshots(snapshots)
logger.info("Modal: saved filesystem snapshot %s for task %s",
snapshot_id[:20], self._task_id)
logger.info("Modal: saved filesystem snapshot %s for task %s", snapshot_id[:20], self._task_id)
except Exception as e:
logger.warning("Modal: filesystem snapshot failed: %s", e)
if hasattr(self._inner, 'stop'):
if hasattr(self._inner, "stop"):
self._inner.stop()

View file

@ -10,11 +10,9 @@ import logging
import os
import shutil
import subprocess
import tempfile
import threading
import uuid
from pathlib import Path
from typing import Any, Dict, Optional
from tools.environments.base import BaseEnvironment
from tools.interrupt import is_interrupted
@ -24,7 +22,7 @@ logger = logging.getLogger(__name__)
_SNAPSHOT_STORE = Path.home() / ".hermes" / "singularity_snapshots.json"
def _load_snapshots() -> Dict[str, str]:
def _load_snapshots() -> dict[str, str]:
if _SNAPSHOT_STORE.exists():
try:
return json.loads(_SNAPSHOT_STORE.read_text())
@ -33,7 +31,7 @@ def _load_snapshots() -> Dict[str, str]:
return {}
def _save_snapshots(data: Dict[str, str]) -> None:
def _save_snapshots(data: dict[str, str]) -> None:
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
@ -42,6 +40,7 @@ def _save_snapshots(data: Dict[str, str]) -> None:
# Singularity helpers (scratch dir, SIF cache, SIF building)
# -------------------------------------------------------------------------
def _get_scratch_dir() -> Path:
"""Get the best directory for Singularity sandboxes.
@ -58,6 +57,7 @@ def _get_scratch_dir() -> Path:
return scratch_path
from tools.environments.base import get_sandbox_dir
sandbox = get_sandbox_dir() / "singularity"
scratch = Path("/scratch")
@ -93,12 +93,12 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
Returns the path unchanged if it's already a .sif file.
For docker:// URLs, checks the cache and builds if needed.
"""
if image.endswith('.sif') and Path(image).exists():
if image.endswith(".sif") and Path(image).exists():
return image
if not image.startswith('docker://'):
if not image.startswith("docker://"):
return image
image_name = image.replace('docker://', '').replace('/', '-').replace(':', '-')
image_name = image.replace("docker://", "").replace("/", "-").replace(":", "-")
cache_dir = _get_apptainer_cache_dir()
sif_path = cache_dir / f"{image_name}.sif"
@ -123,7 +123,10 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
try:
result = subprocess.run(
[executable, "build", str(sif_path), image],
capture_output=True, text=True, timeout=600, env=env,
capture_output=True,
text=True,
timeout=600,
env=env,
)
if result.returncode != 0:
logger.warning("SIF build failed, falling back to docker:// URL")
@ -145,6 +148,7 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
# SingularityEnvironment
# -------------------------------------------------------------------------
class SingularityEnvironment(BaseEnvironment):
"""Hardened Singularity/Apptainer container with resource limits and persistence.
@ -174,7 +178,7 @@ class SingularityEnvironment(BaseEnvironment):
self._instance_started = False
self._persistent = persistent_filesystem
self._task_id = task_id
self._overlay_dir: Optional[Path] = None
self._overlay_dir: Path | None = None
# Resource limits
self._cpu = cpu
@ -215,14 +219,13 @@ class SingularityEnvironment(BaseEnvironment):
if result.returncode != 0:
raise RuntimeError(f"Failed to start instance: {result.stderr}")
self._instance_started = True
logger.info("Singularity instance %s started (persistent=%s)",
self.instance_id, self._persistent)
logger.info("Singularity instance %s started (persistent=%s)", self.instance_id, self._persistent)
except subprocess.TimeoutExpired:
raise RuntimeError("Instance start timed out")
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
def execute(
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
) -> dict:
if not self._instance_started:
return {"output": "Instance not started", "returncode": -1}
@ -235,16 +238,16 @@ class SingularityEnvironment(BaseEnvironment):
exec_command = f"cd {work_dir} && {exec_command}"
work_dir = "/tmp"
cmd = [self.executable, "exec", "--pwd", work_dir,
f"instance://{self.instance_id}",
"bash", "-c", exec_command]
cmd = [self.executable, "exec", "--pwd", work_dir, f"instance://{self.instance_id}", "bash", "-c", exec_command]
try:
import time as _time
_output_chunks = []
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stdin=subprocess.PIPE if stdin_data else subprocess.DEVNULL,
text=True,
)
@ -295,7 +298,9 @@ class SingularityEnvironment(BaseEnvironment):
try:
subprocess.run(
[self.executable, "instance", "stop", self.instance_id],
capture_output=True, text=True, timeout=30,
capture_output=True,
text=True,
timeout=30,
)
logger.info("Singularity instance %s stopped", self.instance_id)
except Exception as e:

View file

@ -24,8 +24,7 @@ class SSHEnvironment(BaseEnvironment):
and a remote kill is attempted over the ControlMaster socket.
"""
def __init__(self, host: str, user: str, cwd: str = "~",
timeout: int = 60, port: int = 22, key_path: str = ""):
def __init__(self, host: str, user: str, cwd: str = "~", timeout: int = 60, port: int = 22, key_path: str = ""):
super().__init__(cwd=cwd, timeout=timeout)
self.host = host
self.user = user
@ -65,12 +64,12 @@ class SSHEnvironment(BaseEnvironment):
except subprocess.TimeoutExpired:
raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out")
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
def execute(
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
) -> dict:
work_dir = cwd or self.cwd
exec_command = self._prepare_command(command)
wrapped = f'cd {work_dir} && {exec_command}'
wrapped = f"cd {work_dir} && {exec_command}"
effective_timeout = timeout or self.timeout
cmd = self._build_ssh_command()
@ -136,8 +135,7 @@ class SSHEnvironment(BaseEnvironment):
def cleanup(self):
if self.control_socket.exists():
try:
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
"-O", "exit", f"{self.user}@{self.host}"]
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}", "-O", "exit", f"{self.user}@{self.host}"]
subprocess.run(cmd, capture_output=True, timeout=5)
except (OSError, subprocess.SubprocessError):
pass

File diff suppressed because it is too large Load diff

View file

@ -3,11 +3,10 @@
import json
import logging
import os
import threading
from typing import Optional
from tools.file_operations import ShellFileOperations
from agent.redact import redact_sensitive_text
from tools.file_operations import ShellFileOperations
logger = logging.getLogger(__name__)
@ -25,14 +24,19 @@ def _get_file_ops(task_id: str = "default") -> ShellFileOperations:
Thread-safe: uses the same per-task creation locks as terminal_tool to
prevent duplicate sandbox creation from concurrent tool calls.
"""
from tools.terminal_tool import (
_active_environments, _env_lock, _create_environment,
_get_env_config, _last_activity, _start_cleanup_thread,
_check_disk_usage_warning,
_creation_locks, _creation_locks_lock,
)
import time
from tools.terminal_tool import (
_active_environments,
_create_environment,
_creation_locks,
_creation_locks_lock,
_env_lock,
_get_env_config,
_last_activity,
_start_cleanup_thread,
)
# Fast path: check cache -- but also verify the underlying environment
# is still alive (it may have been killed by the cleanup thread).
with _file_ops_lock:
@ -143,17 +147,23 @@ def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
result = file_ops.write_file(path, content)
return json.dumps(result.to_dict(), ensure_ascii=False)
except Exception as e:
print(f"[FileTools] write_file error: {type(e).__name__}: {e}", flush=True)
print(f"[FileTools] write_file error: {type(e).__name__}: {e}", flush=True)
return json.dumps({"error": str(e)}, ensure_ascii=False)
def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
new_string: str = None, replace_all: bool = False, patch: str = None,
task_id: str = "default") -> str:
def patch_tool(
mode: str = "replace",
path: str = None,
old_string: str = None,
new_string: str = None,
replace_all: bool = False,
patch: str = None,
task_id: str = "default",
) -> str:
"""Patch a file using replace mode or V4A patch format."""
try:
file_ops = _get_file_ops(task_id)
if mode == "replace":
if not path:
return json.dumps({"error": "path required"})
@ -166,7 +176,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
result = file_ops.patch_v4a(patch)
else:
return json.dumps({"error": f"Unknown mode: {mode}"})
result_dict = result.to_dict()
result_json = json.dumps(result_dict, ensure_ascii=False)
# Hint when old_string not found — saves iterations where the agent
@ -178,20 +188,33 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
return json.dumps({"error": str(e)}, ensure_ascii=False)
def search_tool(pattern: str, target: str = "content", path: str = ".",
file_glob: str = None, limit: int = 50, offset: int = 0,
output_mode: str = "content", context: int = 0,
task_id: str = "default") -> str:
def search_tool(
pattern: str,
target: str = "content",
path: str = ".",
file_glob: str = None,
limit: int = 50,
offset: int = 0,
output_mode: str = "content",
context: int = 0,
task_id: str = "default",
) -> str:
"""Search for content or files."""
try:
file_ops = _get_file_ops(task_id)
result = file_ops.search(
pattern=pattern, path=path, target=target, file_glob=file_glob,
limit=limit, offset=offset, output_mode=output_mode, context=context
pattern=pattern,
path=path,
target=target,
file_glob=file_glob,
limit=limit,
offset=offset,
output_mode=output_mode,
context=context,
)
if hasattr(result, 'matches'):
if hasattr(result, "matches"):
for m in result.matches:
if hasattr(m, 'content') and m.content:
if hasattr(m, "content") and m.content:
m.content = redact_sensitive_text(m.content)
result_dict = result.to_dict()
result_json = json.dumps(result_dict, ensure_ascii=False)
@ -209,7 +232,7 @@ FILE_TOOLS = [
{"name": "read_file", "function": read_file_tool},
{"name": "write_file", "function": write_file_tool},
{"name": "patch", "function": patch_tool},
{"name": "search_files", "function": search_tool}
{"name": "search_files", "function": search_tool},
]
@ -227,8 +250,10 @@ from tools.registry import registry
def _check_file_reqs():
"""Lazy wrapper to avoid circular import with tools/__init__.py."""
from tools import check_file_requirements
return check_file_requirements()
READ_FILE_SCHEMA = {
"name": "read_file",
"description": "Read a text file with line numbers and pagination. Use this instead of cat/head/tail in terminal. Output format: 'LINE_NUM|CONTENT'. Suggests similar filenames if not found. Use offset and limit for large files. NOTE: Cannot read images or binary files — use vision_analyze for images.",
@ -236,11 +261,21 @@ READ_FILE_SCHEMA = {
"type": "object",
"properties": {
"path": {"type": "string", "description": "Path to the file to read (absolute, relative, or ~/path)"},
"offset": {"type": "integer", "description": "Line number to start reading from (1-indexed, default: 1)", "default": 1, "minimum": 1},
"limit": {"type": "integer", "description": "Maximum number of lines to read (default: 500, max: 2000)", "default": 500, "maximum": 2000}
"offset": {
"type": "integer",
"description": "Line number to start reading from (1-indexed, default: 1)",
"default": 1,
"minimum": 1,
},
"limit": {
"type": "integer",
"description": "Maximum number of lines to read (default: 500, max: 2000)",
"default": 500,
"maximum": 2000,
},
},
"required": ["path"]
}
"required": ["path"],
},
}
WRITE_FILE_SCHEMA = {
@ -249,11 +284,14 @@ WRITE_FILE_SCHEMA = {
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "Path to the file to write (will be created if it doesn't exist, overwritten if it does)"},
"content": {"type": "string", "description": "Complete content to write to the file"}
"path": {
"type": "string",
"description": "Path to the file to write (will be created if it doesn't exist, overwritten if it does)",
},
"content": {"type": "string", "description": "Complete content to write to the file"},
},
"required": ["path", "content"]
}
"required": ["path", "content"],
},
}
PATCH_SCHEMA = {
@ -262,15 +300,33 @@ PATCH_SCHEMA = {
"parameters": {
"type": "object",
"properties": {
"mode": {"type": "string", "enum": ["replace", "patch"], "description": "Edit mode: 'replace' for targeted find-and-replace, 'patch' for V4A multi-file patches", "default": "replace"},
"mode": {
"type": "string",
"enum": ["replace", "patch"],
"description": "Edit mode: 'replace' for targeted find-and-replace, 'patch' for V4A multi-file patches",
"default": "replace",
},
"path": {"type": "string", "description": "File path to edit (required for 'replace' mode)"},
"old_string": {"type": "string", "description": "Text to find in the file (required for 'replace' mode). Must be unique in the file unless replace_all=true. Include enough surrounding context to ensure uniqueness."},
"new_string": {"type": "string", "description": "Replacement text (required for 'replace' mode). Can be empty string to delete the matched text."},
"replace_all": {"type": "boolean", "description": "Replace all occurrences instead of requiring a unique match (default: false)", "default": False},
"patch": {"type": "string", "description": "V4A format patch content (required for 'patch' mode). Format:\n*** Begin Patch\n*** Update File: path/to/file\n@@ context hint @@\n context line\n-removed line\n+added line\n*** End Patch"}
"old_string": {
"type": "string",
"description": "Text to find in the file (required for 'replace' mode). Must be unique in the file unless replace_all=true. Include enough surrounding context to ensure uniqueness.",
},
"new_string": {
"type": "string",
"description": "Replacement text (required for 'replace' mode). Can be empty string to delete the matched text.",
},
"replace_all": {
"type": "boolean",
"description": "Replace all occurrences instead of requiring a unique match (default: false)",
"default": False,
},
"patch": {
"type": "string",
"description": "V4A format patch content (required for 'patch' mode). Format:\n*** Begin Patch\n*** Update File: path/to/file\n@@ context hint @@\n context line\n-removed line\n+added line\n*** End Patch",
},
},
"required": ["mode"]
}
"required": ["mode"],
},
}
SEARCH_FILES_SCHEMA = {
@ -279,23 +335,57 @@ SEARCH_FILES_SCHEMA = {
"parameters": {
"type": "object",
"properties": {
"pattern": {"type": "string", "description": "Regex pattern for content search, or glob pattern (e.g., '*.py') for file search"},
"target": {"type": "string", "enum": ["content", "files"], "description": "'content' searches inside file contents, 'files' searches for files by name", "default": "content"},
"path": {"type": "string", "description": "Directory or file to search in (default: current working directory)", "default": "."},
"file_glob": {"type": "string", "description": "Filter files by pattern in grep mode (e.g., '*.py' to only search Python files)"},
"limit": {"type": "integer", "description": "Maximum number of results to return (default: 50)", "default": 50},
"offset": {"type": "integer", "description": "Skip first N results for pagination (default: 0)", "default": 0},
"output_mode": {"type": "string", "enum": ["content", "files_only", "count"], "description": "Output format for grep mode: 'content' shows matching lines with line numbers, 'files_only' lists file paths, 'count' shows match counts per file", "default": "content"},
"context": {"type": "integer", "description": "Number of context lines before and after each match (grep mode only)", "default": 0}
"pattern": {
"type": "string",
"description": "Regex pattern for content search, or glob pattern (e.g., '*.py') for file search",
},
"target": {
"type": "string",
"enum": ["content", "files"],
"description": "'content' searches inside file contents, 'files' searches for files by name",
"default": "content",
},
"path": {
"type": "string",
"description": "Directory or file to search in (default: current working directory)",
"default": ".",
},
"file_glob": {
"type": "string",
"description": "Filter files by pattern in grep mode (e.g., '*.py' to only search Python files)",
},
"limit": {
"type": "integer",
"description": "Maximum number of results to return (default: 50)",
"default": 50,
},
"offset": {
"type": "integer",
"description": "Skip first N results for pagination (default: 0)",
"default": 0,
},
"output_mode": {
"type": "string",
"enum": ["content", "files_only", "count"],
"description": "Output format for grep mode: 'content' shows matching lines with line numbers, 'files_only' lists file paths, 'count' shows match counts per file",
"default": "content",
},
"context": {
"type": "integer",
"description": "Number of context lines before and after each match (grep mode only)",
"default": 0,
},
},
"required": ["pattern"]
}
"required": ["pattern"],
},
}
def _handle_read_file(args, **kw):
tid = kw.get("task_id") or "default"
return read_file_tool(path=args.get("path", ""), offset=args.get("offset", 1), limit=args.get("limit", 500), task_id=tid)
return read_file_tool(
path=args.get("path", ""), offset=args.get("offset", 1), limit=args.get("limit", 500), task_id=tid
)
def _handle_write_file(args, **kw):
@ -306,9 +396,14 @@ def _handle_write_file(args, **kw):
def _handle_patch(args, **kw):
tid = kw.get("task_id") or "default"
return patch_tool(
mode=args.get("mode", "replace"), path=args.get("path"),
old_string=args.get("old_string"), new_string=args.get("new_string"),
replace_all=args.get("replace_all", False), patch=args.get("patch"), task_id=tid)
mode=args.get("mode", "replace"),
path=args.get("path"),
old_string=args.get("old_string"),
new_string=args.get("new_string"),
replace_all=args.get("replace_all", False),
patch=args.get("patch"),
task_id=tid,
)
def _handle_search_files(args, **kw):
@ -317,12 +412,29 @@ def _handle_search_files(args, **kw):
raw_target = args.get("target", "content")
target = target_map.get(raw_target, raw_target)
return search_tool(
pattern=args.get("pattern", ""), target=target, path=args.get("path", "."),
file_glob=args.get("file_glob"), limit=args.get("limit", 50), offset=args.get("offset", 0),
output_mode=args.get("output_mode", "content"), context=args.get("context", 0), task_id=tid)
pattern=args.get("pattern", ""),
target=target,
path=args.get("path", "."),
file_glob=args.get("file_glob"),
limit=args.get("limit", 50),
offset=args.get("offset", 0),
output_mode=args.get("output_mode", "content"),
context=args.get("context", 0),
task_id=tid,
)
registry.register(name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs)
registry.register(name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs)
registry.register(
name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs
)
registry.register(
name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs
)
registry.register(name="patch", toolset="file", schema=PATCH_SCHEMA, handler=_handle_patch, check_fn=_check_file_reqs)
registry.register(name="search_files", toolset="file", schema=SEARCH_FILES_SCHEMA, handler=_handle_search_files, check_fn=_check_file_reqs)
registry.register(
name="search_files",
toolset="file",
schema=SEARCH_FILES_SCHEMA,
handler=_handle_search_files,
check_fn=_check_file_reqs,
)

View file

@ -19,7 +19,7 @@ The 9-strategy chain (inspired by OpenCode):
Usage:
from tools.fuzzy_match import fuzzy_find_and_replace
new_content, match_count, error = fuzzy_find_and_replace(
content="def foo():\\n pass",
old_string="def foo():",
@ -29,21 +29,22 @@ Usage:
"""
import re
from typing import Tuple, Optional, List, Callable
from collections.abc import Callable
from difflib import SequenceMatcher
def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
replace_all: bool = False) -> Tuple[str, int, Optional[str]]:
def fuzzy_find_and_replace(
content: str, old_string: str, new_string: str, replace_all: bool = False
) -> tuple[str, int, str | None]:
"""
Find and replace text using a chain of increasingly fuzzy matching strategies.
Args:
content: The file content to search in
old_string: The text to find
new_string: The replacement text
replace_all: If True, replace all occurrences; if False, require uniqueness
Returns:
Tuple of (new_content, match_count, error_message)
- If successful: (modified_content, number_of_replacements, None)
@ -51,12 +52,12 @@ def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
"""
if not old_string:
return content, 0, "old_string cannot be empty"
if old_string == new_string:
return content, 0, "old_string and new_string are identical"
# Try each matching strategy in order
strategies: List[Tuple[str, Callable]] = [
strategies: list[tuple[str, Callable]] = [
("exact", _strategy_exact),
("line_trimmed", _strategy_line_trimmed),
("whitespace_normalized", _strategy_whitespace_normalized),
@ -66,46 +67,50 @@ def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
("block_anchor", _strategy_block_anchor),
("context_aware", _strategy_context_aware),
]
for strategy_name, strategy_fn in strategies:
matches = strategy_fn(content, old_string)
if matches:
# Found matches with this strategy
if len(matches) > 1 and not replace_all:
return content, 0, (
f"Found {len(matches)} matches for old_string. "
f"Provide more context to make it unique, or use replace_all=True."
return (
content,
0,
(
f"Found {len(matches)} matches for old_string. "
f"Provide more context to make it unique, or use replace_all=True."
),
)
# Perform replacement
new_content = _apply_replacements(content, matches, new_string)
return new_content, len(matches), None
# No strategy found a match
return content, 0, "Could not find a match for old_string in the file"
def _apply_replacements(content: str, matches: List[Tuple[int, int]], new_string: str) -> str:
def _apply_replacements(content: str, matches: list[tuple[int, int]], new_string: str) -> str:
"""
Apply replacements at the given positions.
Args:
content: Original content
matches: List of (start, end) positions to replace
new_string: Replacement text
Returns:
Content with replacements applied
"""
# Sort matches by position (descending) to replace from end to start
# This preserves positions of earlier matches
sorted_matches = sorted(matches, key=lambda x: x[0], reverse=True)
result = content
for start, end in sorted_matches:
result = result[:start] + new_string + result[end:]
return result
@ -113,7 +118,8 @@ def _apply_replacements(content: str, matches: List[Tuple[int, int]], new_string
# Matching Strategies
# =============================================================================
def _strategy_exact(content: str, pattern: str) -> List[Tuple[int, int]]:
def _strategy_exact(content: str, pattern: str) -> list[tuple[int, int]]:
"""Strategy 1: Exact string match."""
matches = []
start = 0
@ -126,206 +132,201 @@ def _strategy_exact(content: str, pattern: str) -> List[Tuple[int, int]]:
return matches
def _strategy_line_trimmed(content: str, pattern: str) -> List[Tuple[int, int]]:
def _strategy_line_trimmed(content: str, pattern: str) -> list[tuple[int, int]]:
"""
Strategy 2: Match with line-by-line whitespace trimming.
Strips leading/trailing whitespace from each line before matching.
"""
# Normalize pattern and content by trimming each line
pattern_lines = [line.strip() for line in pattern.split('\n')]
pattern_normalized = '\n'.join(pattern_lines)
content_lines = content.split('\n')
pattern_lines = [line.strip() for line in pattern.split("\n")]
pattern_normalized = "\n".join(pattern_lines)
content_lines = content.split("\n")
content_normalized_lines = [line.strip() for line in content_lines]
# Build mapping from normalized positions back to original positions
return _find_normalized_matches(
content, content_lines, content_normalized_lines,
pattern, pattern_normalized
)
return _find_normalized_matches(content, content_lines, content_normalized_lines, pattern, pattern_normalized)
def _strategy_whitespace_normalized(content: str, pattern: str) -> List[Tuple[int, int]]:
def _strategy_whitespace_normalized(content: str, pattern: str) -> list[tuple[int, int]]:
"""
Strategy 3: Collapse multiple whitespace to single space.
"""
def normalize(s):
# Collapse multiple spaces/tabs to single space, preserve newlines
return re.sub(r'[ \t]+', ' ', s)
return re.sub(r"[ \t]+", " ", s)
pattern_normalized = normalize(pattern)
content_normalized = normalize(content)
# Find in normalized, map back to original
matches_in_normalized = _strategy_exact(content_normalized, pattern_normalized)
if not matches_in_normalized:
return []
# Map positions back to original content
return _map_normalized_positions(content, content_normalized, matches_in_normalized)
def _strategy_indentation_flexible(content: str, pattern: str) -> List[Tuple[int, int]]:
def _strategy_indentation_flexible(content: str, pattern: str) -> list[tuple[int, int]]:
"""
Strategy 4: Ignore indentation differences entirely.
Strips all leading whitespace from lines before matching.
"""
def strip_indent(s):
return '\n'.join(line.lstrip() for line in s.split('\n'))
return "\n".join(line.lstrip() for line in s.split("\n"))
pattern_stripped = strip_indent(pattern)
content_lines = content.split('\n')
content_lines = content.split("\n")
content_stripped_lines = [line.lstrip() for line in content_lines]
pattern_lines = [line.lstrip() for line in pattern.split('\n')]
return _find_normalized_matches(
content, content_lines, content_stripped_lines,
pattern, '\n'.join(pattern_lines)
)
pattern_lines = [line.lstrip() for line in pattern.split("\n")]
return _find_normalized_matches(content, content_lines, content_stripped_lines, pattern, "\n".join(pattern_lines))
def _strategy_escape_normalized(content: str, pattern: str) -> List[Tuple[int, int]]:
def _strategy_escape_normalized(content: str, pattern: str) -> list[tuple[int, int]]:
"""
Strategy 5: Convert escape sequences to actual characters.
Handles \\n -> newline, \\t -> tab, etc.
"""
def unescape(s):
# Convert common escape sequences
return s.replace('\\n', '\n').replace('\\t', '\t').replace('\\r', '\r')
return s.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r")
pattern_unescaped = unescape(pattern)
if pattern_unescaped == pattern:
# No escapes to convert, skip this strategy
return []
return _strategy_exact(content, pattern_unescaped)
def _strategy_trimmed_boundary(content: str, pattern: str) -> List[Tuple[int, int]]:
def _strategy_trimmed_boundary(content: str, pattern: str) -> list[tuple[int, int]]:
"""
Strategy 6: Trim whitespace from first and last lines only.
Useful when the pattern boundaries have whitespace differences.
"""
pattern_lines = pattern.split('\n')
pattern_lines = pattern.split("\n")
if not pattern_lines:
return []
# Trim only first and last lines
pattern_lines[0] = pattern_lines[0].strip()
if len(pattern_lines) > 1:
pattern_lines[-1] = pattern_lines[-1].strip()
modified_pattern = '\n'.join(pattern_lines)
content_lines = content.split('\n')
modified_pattern = "\n".join(pattern_lines)
content_lines = content.split("\n")
# Search through content for matching block
matches = []
pattern_line_count = len(pattern_lines)
for i in range(len(content_lines) - pattern_line_count + 1):
block_lines = content_lines[i:i + pattern_line_count]
block_lines = content_lines[i : i + pattern_line_count]
# Trim first and last of this block
check_lines = block_lines.copy()
check_lines[0] = check_lines[0].strip()
if len(check_lines) > 1:
check_lines[-1] = check_lines[-1].strip()
if '\n'.join(check_lines) == modified_pattern:
if "\n".join(check_lines) == modified_pattern:
# Found match - calculate original positions
start_pos = sum(len(line) + 1 for line in content_lines[:i])
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1
if end_pos >= len(content):
end_pos = len(content)
matches.append((start_pos, end_pos))
return matches
def _strategy_block_anchor(content: str, pattern: str) -> List[Tuple[int, int]]:
def _strategy_block_anchor(content: str, pattern: str) -> list[tuple[int, int]]:
"""
Strategy 7: Match by anchoring on first and last lines.
If first and last lines match exactly, accept middle with 70% similarity.
"""
pattern_lines = pattern.split('\n')
pattern_lines = pattern.split("\n")
if len(pattern_lines) < 2:
return [] # Need at least 2 lines for anchoring
first_line = pattern_lines[0].strip()
last_line = pattern_lines[-1].strip()
content_lines = content.split('\n')
content_lines = content.split("\n")
matches = []
pattern_line_count = len(pattern_lines)
for i in range(len(content_lines) - pattern_line_count + 1):
# Check if first and last lines match
if (content_lines[i].strip() == first_line and
content_lines[i + pattern_line_count - 1].strip() == last_line):
if content_lines[i].strip() == first_line and content_lines[i + pattern_line_count - 1].strip() == last_line:
# Check middle similarity
if pattern_line_count <= 2:
# Only first and last, they match
similarity = 1.0
else:
content_middle = '\n'.join(content_lines[i+1:i+pattern_line_count-1])
pattern_middle = '\n'.join(pattern_lines[1:-1])
content_middle = "\n".join(content_lines[i + 1 : i + pattern_line_count - 1])
pattern_middle = "\n".join(pattern_lines[1:-1])
similarity = SequenceMatcher(None, content_middle, pattern_middle).ratio()
if similarity >= 0.70:
# Calculate positions
start_pos = sum(len(line) + 1 for line in content_lines[:i])
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1
if end_pos >= len(content):
end_pos = len(content)
matches.append((start_pos, end_pos))
return matches
def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]:
def _strategy_context_aware(content: str, pattern: str) -> list[tuple[int, int]]:
"""
Strategy 8: Line-by-line similarity with 50% threshold.
Finds blocks where at least 50% of lines have high similarity.
"""
pattern_lines = pattern.split('\n')
content_lines = content.split('\n')
pattern_lines = pattern.split("\n")
content_lines = content.split("\n")
if not pattern_lines:
return []
matches = []
pattern_line_count = len(pattern_lines)
for i in range(len(content_lines) - pattern_line_count + 1):
block_lines = content_lines[i:i + pattern_line_count]
block_lines = content_lines[i : i + pattern_line_count]
# Calculate line-by-line similarity
high_similarity_count = 0
for p_line, c_line in zip(pattern_lines, block_lines):
sim = SequenceMatcher(None, p_line.strip(), c_line.strip()).ratio()
if sim >= 0.80:
high_similarity_count += 1
# Need at least 50% of lines to have high similarity
if high_similarity_count >= len(pattern_lines) * 0.5:
start_pos = sum(len(line) + 1 for line in content_lines[:i])
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1
if end_pos >= len(content):
end_pos = len(content)
matches.append((start_pos, end_pos))
return matches
@ -333,74 +334,76 @@ def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]
# Helper Functions
# =============================================================================
def _find_normalized_matches(content: str, content_lines: List[str],
content_normalized_lines: List[str],
pattern: str, pattern_normalized: str) -> List[Tuple[int, int]]:
def _find_normalized_matches(
content: str, content_lines: list[str], content_normalized_lines: list[str], pattern: str, pattern_normalized: str
) -> list[tuple[int, int]]:
"""
Find matches in normalized content and map back to original positions.
Args:
content: Original content string
content_lines: Original content split by lines
content_normalized_lines: Normalized content lines
pattern: Original pattern
pattern_normalized: Normalized pattern
Returns:
List of (start, end) positions in the original content
"""
pattern_norm_lines = pattern_normalized.split('\n')
pattern_norm_lines = pattern_normalized.split("\n")
num_pattern_lines = len(pattern_norm_lines)
matches = []
for i in range(len(content_normalized_lines) - num_pattern_lines + 1):
# Check if this block matches
block = '\n'.join(content_normalized_lines[i:i + num_pattern_lines])
block = "\n".join(content_normalized_lines[i : i + num_pattern_lines])
if block == pattern_normalized:
# Found a match - calculate original positions
start_pos = sum(len(line) + 1 for line in content_lines[:i])
end_pos = sum(len(line) + 1 for line in content_lines[:i + num_pattern_lines]) - 1
end_pos = sum(len(line) + 1 for line in content_lines[: i + num_pattern_lines]) - 1
# Handle case where end is past content
if end_pos >= len(content):
end_pos = len(content)
matches.append((start_pos, end_pos))
return matches
def _map_normalized_positions(original: str, normalized: str,
normalized_matches: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
def _map_normalized_positions(
original: str, normalized: str, normalized_matches: list[tuple[int, int]]
) -> list[tuple[int, int]]:
"""
Map positions from normalized string back to original.
This is a best-effort mapping that works for whitespace normalization.
"""
if not normalized_matches:
return []
# Build character mapping from normalized to original
orig_to_norm = [] # orig_to_norm[i] = position in normalized
orig_idx = 0
norm_idx = 0
while orig_idx < len(original) and norm_idx < len(normalized):
if original[orig_idx] == normalized[norm_idx]:
orig_to_norm.append(norm_idx)
orig_idx += 1
norm_idx += 1
elif original[orig_idx] in ' \t' and normalized[norm_idx] == ' ':
elif original[orig_idx] in " \t" and normalized[norm_idx] == " ":
# Original has space/tab, normalized collapsed to space
orig_to_norm.append(norm_idx)
orig_idx += 1
# Don't advance norm_idx yet - wait until all whitespace consumed
if orig_idx < len(original) and original[orig_idx] not in ' \t':
if orig_idx < len(original) and original[orig_idx] not in " \t":
norm_idx += 1
elif original[orig_idx] in ' \t':
elif original[orig_idx] in " \t":
# Extra whitespace in original
orig_to_norm.append(norm_idx)
orig_idx += 1
@ -408,21 +411,21 @@ def _map_normalized_positions(original: str, normalized: str,
# Mismatch - shouldn't happen with our normalization
orig_to_norm.append(norm_idx)
orig_idx += 1
# Fill remaining
while orig_idx < len(original):
orig_to_norm.append(len(normalized))
orig_idx += 1
# Reverse mapping: for each normalized position, find original range
norm_to_orig_start = {}
norm_to_orig_end = {}
for orig_pos, norm_pos in enumerate(orig_to_norm):
if norm_pos not in norm_to_orig_start:
norm_to_orig_start[norm_pos] = orig_pos
norm_to_orig_end[norm_pos] = orig_pos
# Map matches
original_matches = []
for norm_start, norm_end in normalized_matches:
@ -432,17 +435,17 @@ def _map_normalized_positions(original: str, normalized: str,
else:
# Find nearest
orig_start = min(i for i, n in enumerate(orig_to_norm) if n >= norm_start)
# Find original end
if norm_end - 1 in norm_to_orig_end:
orig_end = norm_to_orig_end[norm_end - 1] + 1
else:
orig_end = orig_start + (norm_end - norm_start)
# Expand to include trailing whitespace that was normalized
while orig_end < len(original) and original[orig_end] in ' \t':
while orig_end < len(original) and original[orig_end] in " \t":
orig_end += 1
original_matches.append((orig_start, min(orig_end, len(original))))
return original_matches

View file

@ -15,7 +15,7 @@ import json
import logging
import os
import re
from typing import Any, Dict, Optional
from typing import Any
logger = logging.getLogger(__name__)
@ -35,23 +35,26 @@ def _get_config():
_HASS_TOKEN or os.getenv("HASS_TOKEN", ""),
)
# Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1")
_ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+$")
# Service domains blocked for security -- these allow arbitrary code/command
# execution on the HA host or enable SSRF attacks on the local network.
# HA provides zero service-level access control; all safety must be in our layer.
_BLOCKED_DOMAINS = frozenset({
"shell_command", # arbitrary shell commands as root in HA container
"command_line", # sensors/switches that execute shell commands
"python_script", # sandboxed but can escalate via hass.services.call()
"pyscript", # scripting integration with broader access
"hassio", # addon control, host shutdown/reboot, stdin to containers
"rest_command", # HTTP requests from HA server (SSRF vector)
})
_BLOCKED_DOMAINS = frozenset(
{
"shell_command", # arbitrary shell commands as root in HA container
"command_line", # sensors/switches that execute shell commands
"python_script", # sandboxed but can escalate via hass.services.call()
"pyscript", # scripting integration with broader access
"hassio", # addon control, host shutdown/reboot, stdin to containers
"rest_command", # HTTP requests from HA server (SSRF vector)
}
)
def _get_headers(token: str = "") -> Dict[str, str]:
def _get_headers(token: str = "") -> dict[str, str]:
"""Return authorization headers for HA REST API."""
if not token:
_, token = _get_config()
@ -65,11 +68,12 @@ def _get_headers(token: str = "") -> Dict[str, str]:
# Async helpers (called from sync handlers via run_until_complete)
# ---------------------------------------------------------------------------
def _filter_and_summarize(
states: list,
domain: Optional[str] = None,
area: Optional[str] = None,
) -> Dict[str, Any]:
domain: str | None = None,
area: str | None = None,
) -> dict[str, Any]:
"""Filter raw HA states by domain/area and return a compact summary."""
if domain:
states = [s for s in states if s.get("entity_id", "").startswith(f"{domain}.")]
@ -77,26 +81,29 @@ def _filter_and_summarize(
if area:
area_lower = area.lower()
states = [
s for s in states
s
for s in states
if area_lower in (s.get("attributes", {}).get("friendly_name", "") or "").lower()
or area_lower in (s.get("attributes", {}).get("area", "") or "").lower()
]
entities = []
for s in states:
entities.append({
"entity_id": s["entity_id"],
"state": s["state"],
"friendly_name": s.get("attributes", {}).get("friendly_name", ""),
})
entities.append(
{
"entity_id": s["entity_id"],
"state": s["state"],
"friendly_name": s.get("attributes", {}).get("friendly_name", ""),
}
)
return {"count": len(entities), "entities": entities}
async def _async_list_entities(
domain: Optional[str] = None,
area: Optional[str] = None,
) -> Dict[str, Any]:
domain: str | None = None,
area: str | None = None,
) -> dict[str, Any]:
"""Fetch entity states from HA and optionally filter by domain/area."""
import aiohttp
@ -110,7 +117,7 @@ async def _async_list_entities(
return _filter_and_summarize(states, domain, area)
async def _async_get_state(entity_id: str) -> Dict[str, Any]:
async def _async_get_state(entity_id: str) -> dict[str, Any]:
"""Fetch detailed state of a single entity."""
import aiohttp
@ -131,11 +138,11 @@ async def _async_get_state(entity_id: str) -> Dict[str, Any]:
def _build_service_payload(
entity_id: Optional[str] = None,
data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
entity_id: str | None = None,
data: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Build the JSON payload for a HA service call."""
payload: Dict[str, Any] = {}
payload: dict[str, Any] = {}
if data:
payload.update(data)
# entity_id parameter takes precedence over data["entity_id"]
@ -148,15 +155,17 @@ def _parse_service_response(
domain: str,
service: str,
result: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Parse HA service call response into a structured result."""
affected = []
if isinstance(result, list):
for s in result:
affected.append({
"entity_id": s.get("entity_id", ""),
"state": s.get("state", ""),
})
affected.append(
{
"entity_id": s.get("entity_id", ""),
"state": s.get("state", ""),
}
)
return {
"success": True,
@ -168,9 +177,9 @@ def _parse_service_response(
async def _async_call_service(
domain: str,
service: str,
entity_id: Optional[str] = None,
data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
entity_id: str | None = None,
data: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Call a Home Assistant service."""
import aiohttp
@ -178,15 +187,17 @@ async def _async_call_service(
url = f"{hass_url}/api/services/{domain}/{service}"
payload = _build_service_payload(entity_id, data)
async with aiohttp.ClientSession() as session:
async with session.post(
async with (
aiohttp.ClientSession() as session,
session.post(
url,
headers=_get_headers(hass_token),
json=payload,
timeout=aiohttp.ClientTimeout(total=15),
) as resp:
resp.raise_for_status()
result = await resp.json()
) as resp,
):
resp.raise_for_status()
result = await resp.json()
return _parse_service_response(domain, service, result)
@ -195,6 +206,7 @@ async def _async_call_service(
# Sync wrappers (handler signature: (args, **kw) -> str)
# ---------------------------------------------------------------------------
def _run_async(coro):
"""Run an async coroutine from a sync handler."""
try:
@ -205,6 +217,7 @@ def _run_async(coro):
if loop and loop.is_running():
# Already inside an event loop -- create a new thread
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
future = pool.submit(asyncio.run, coro)
return future.result(timeout=30)
@ -247,10 +260,12 @@ def _handle_call_service(args: dict, **kw) -> str:
return json.dumps({"error": "Missing required parameters: domain and service"})
if domain in _BLOCKED_DOMAINS:
return json.dumps({
"error": f"Service domain '{domain}' is blocked for security. "
f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
})
return json.dumps(
{
"error": f"Service domain '{domain}' is blocked for security. "
f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
}
)
entity_id = args.get("entity_id")
if entity_id and not _ENTITY_ID_RE.match(entity_id):
@ -269,7 +284,8 @@ def _handle_call_service(args: dict, **kw) -> str:
# List services
# ---------------------------------------------------------------------------
async def _async_list_services(domain: Optional[str] = None) -> Dict[str, Any]:
async def _async_list_services(domain: str | None = None) -> dict[str, Any]:
"""Fetch available services from HA and optionally filter by domain."""
import aiohttp
@ -290,13 +306,10 @@ async def _async_list_services(domain: Optional[str] = None) -> Dict[str, Any]:
d = svc_domain.get("domain", "")
domain_services = {}
for svc_name, svc_info in svc_domain.get("services", {}).items():
svc_entry: Dict[str, Any] = {"description": svc_info.get("description", "")}
svc_entry: dict[str, Any] = {"description": svc_info.get("description", "")}
fields = svc_info.get("fields", {})
if fields:
svc_entry["fields"] = {
k: v.get("description", "") for k, v in fields.items()
if isinstance(v, dict)
}
svc_entry["fields"] = {k: v.get("description", "") for k, v in fields.items() if isinstance(v, dict)}
domain_services[svc_name] = svc_entry
result.append({"domain": d, "services": domain_services})
@ -318,6 +331,7 @@ def _handle_list_services(args: dict, **kw) -> str:
# Availability check
# ---------------------------------------------------------------------------
def _check_ha_available() -> bool:
"""Tool is only available when HASS_TOKEN is set."""
return bool(os.getenv("HASS_TOKEN"))
@ -369,8 +383,7 @@ HA_GET_STATE_SCHEMA = {
"entity_id": {
"type": "string",
"description": (
"The entity ID to query (e.g. 'light.living_room', "
"'climate.thermostat', 'sensor.temperature')."
"The entity ID to query (e.g. 'light.living_room', 'climate.thermostat', 'sensor.temperature')."
),
},
},
@ -392,8 +405,7 @@ HA_LIST_SERVICES_SCHEMA = {
"domain": {
"type": "string",
"description": (
"Filter by domain (e.g. 'light', 'climate', 'switch'). "
"Omit to list services for all domains."
"Filter by domain (e.g. 'light', 'climate', 'switch'). Omit to list services for all domains."
),
},
},
@ -428,8 +440,7 @@ HA_CALL_SERVICE_SCHEMA = {
"entity_id": {
"type": "string",
"description": (
"Target entity ID (e.g. 'light.living_room'). "
"Some services (like scene.turn_on) may not need this."
"Target entity ID (e.g. 'light.living_room'). Some services (like scene.turn_on) may not need this."
),
},
"data": {

View file

@ -65,6 +65,7 @@ HONCHO_TOOL_SCHEMA = {
# ── Tool handler ──
def _handle_query_user_context(args: dict, **kw) -> str:
"""Execute the Honcho context query."""
query = args.get("query", "")
@ -84,6 +85,7 @@ def _handle_query_user_context(args: dict, **kw) -> str:
# ── Availability check ──
def _check_honcho_available() -> bool:
"""Tool is only available when Honcho is active."""
return _session_manager is not None and _session_key is not None

View file

@ -2,7 +2,7 @@
"""
Image Generation Tools Module
This module provides image generation tools using FAL.ai's FLUX 2 Pro model with
This module provides image generation tools using FAL.ai's FLUX 2 Pro model with
automatic upscaling via FAL.ai's Clarity Upscaler for enhanced image quality.
Available tools:
@ -19,7 +19,7 @@ Features:
Usage:
from image_generation_tool import image_generate_tool
import asyncio
# Generate and automatically upscale an image
result = await image_generate_tool(
prompt="A serene mountain landscape with cherry blossoms",
@ -28,12 +28,14 @@ Usage:
)
"""
import datetime
import json
import logging
import os
import datetime
from typing import Dict, Any, Optional, Union
from typing import Any
import fal_client
from tools.debug_helpers import DebugSession
logger = logging.getLogger(__name__)
@ -51,11 +53,7 @@ ENABLE_SAFETY_CHECKER = False
SAFETY_TOLERANCE = "5" # Maximum tolerance (1-5, where 5 is most permissive)
# Aspect ratio mapping - simplified choices for model to select
ASPECT_RATIO_MAP = {
"landscape": "landscape_16_9",
"square": "square_hd",
"portrait": "portrait_16_9"
}
ASPECT_RATIO_MAP = {"landscape": "landscape_16_9", "square": "square_hd", "portrait": "portrait_16_9"}
VALID_ASPECT_RATIOS = list(ASPECT_RATIO_MAP.keys())
# Configuration for automatic upscaling
@ -70,9 +68,7 @@ UPSCALER_GUIDANCE_SCALE = 4
UPSCALER_NUM_INFERENCE_STEPS = 18
# Valid parameter values for validation based on FLUX 2 Pro documentation
VALID_IMAGE_SIZES = [
"square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"
]
VALID_IMAGE_SIZES = ["square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"]
VALID_OUTPUT_FORMATS = ["jpeg", "png"]
VALID_ACCELERATION_MODES = ["none", "regular", "high"]
@ -80,16 +76,16 @@ _debug = DebugSession("image_tools", env_var="IMAGE_TOOLS_DEBUG")
def _validate_parameters(
image_size: Union[str, Dict[str, int]],
image_size: str | dict[str, int],
num_inference_steps: int,
guidance_scale: float,
num_images: int,
output_format: str,
acceleration: str = "none"
) -> Dict[str, Any]:
acceleration: str = "none",
) -> dict[str, Any]:
"""
Validate and normalize image generation parameters for FLUX 2 Pro model.
Args:
image_size: Either a preset string or custom size dict
num_inference_steps: Number of inference steps
@ -97,15 +93,15 @@ def _validate_parameters(
num_images: Number of images to generate
output_format: Output format for images
acceleration: Acceleration mode for generation speed
Returns:
Dict[str, Any]: Validated and normalized parameters
Raises:
ValueError: If any parameter is invalid
"""
validated = {}
# Validate image_size
if isinstance(image_size, str):
if image_size not in VALID_IMAGE_SIZES:
@ -123,52 +119,52 @@ def _validate_parameters(
validated["image_size"] = image_size
else:
raise ValueError("image_size must be either a preset string or a dict with width/height")
# Validate num_inference_steps
if not isinstance(num_inference_steps, int) or num_inference_steps < 1 or num_inference_steps > 100:
raise ValueError("num_inference_steps must be an integer between 1 and 100")
validated["num_inference_steps"] = num_inference_steps
# Validate guidance_scale (FLUX 2 Pro default is 4.5)
if not isinstance(guidance_scale, (int, float)) or guidance_scale < 0.1 or guidance_scale > 20.0:
raise ValueError("guidance_scale must be a number between 0.1 and 20.0")
validated["guidance_scale"] = float(guidance_scale)
# Validate num_images
if not isinstance(num_images, int) or num_images < 1 or num_images > 4:
raise ValueError("num_images must be an integer between 1 and 4")
validated["num_images"] = num_images
# Validate output_format
if output_format not in VALID_OUTPUT_FORMATS:
raise ValueError(f"Invalid output_format '{output_format}'. Must be one of: {VALID_OUTPUT_FORMATS}")
validated["output_format"] = output_format
# Validate acceleration
if acceleration not in VALID_ACCELERATION_MODES:
raise ValueError(f"Invalid acceleration '{acceleration}'. Must be one of: {VALID_ACCELERATION_MODES}")
validated["acceleration"] = acceleration
return validated
def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]:
def _upscale_image(image_url: str, original_prompt: str) -> dict[str, Any]:
"""
Upscale an image using FAL.ai's Clarity Upscaler.
Uses the synchronous fal_client API to avoid event loop lifecycle issues
when called from threaded contexts (e.g. gateway thread pool).
Args:
image_url (str): URL of the image to upscale
original_prompt (str): Original prompt used to generate the image
Returns:
Dict[str, Any]: Upscaled image data or None if upscaling fails
"""
try:
logger.info("Upscaling image with Clarity Upscaler...")
# Prepare arguments for upscaler
upscaler_arguments = {
"image_url": image_url,
@ -179,35 +175,36 @@ def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]:
"resemblance": UPSCALER_RESEMBLANCE,
"guidance_scale": UPSCALER_GUIDANCE_SCALE,
"num_inference_steps": UPSCALER_NUM_INFERENCE_STEPS,
"enable_safety_checker": UPSCALER_SAFETY_CHECKER
"enable_safety_checker": UPSCALER_SAFETY_CHECKER,
}
# Use sync API — fal_client.submit() uses httpx.Client (no event loop).
# The async API (submit_async) caches a global httpx.AsyncClient via
# @cached_property, which breaks when asyncio.run() destroys the loop
# between calls (gateway thread-pool pattern).
handler = fal_client.submit(
UPSCALER_MODEL,
arguments=upscaler_arguments
)
handler = fal_client.submit(UPSCALER_MODEL, arguments=upscaler_arguments)
# Get the upscaled result (sync — blocks until done)
result = handler.get()
if result and "image" in result:
upscaled_image = result["image"]
logger.info("Image upscaled successfully to %sx%s", upscaled_image.get('width', 'unknown'), upscaled_image.get('height', 'unknown'))
logger.info(
"Image upscaled successfully to %sx%s",
upscaled_image.get("width", "unknown"),
upscaled_image.get("height", "unknown"),
)
return {
"url": upscaled_image["url"],
"width": upscaled_image.get("width", 0),
"height": upscaled_image.get("height", 0),
"upscaled": True,
"upscale_factor": UPSCALER_FACTOR
"upscale_factor": UPSCALER_FACTOR,
}
else:
logger.error("Upscaler returned invalid response")
return None
except Exception as e:
logger.error("Error upscaling image: %s", e)
return None
@ -220,16 +217,16 @@ def image_generate_tool(
guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
num_images: int = DEFAULT_NUM_IMAGES,
output_format: str = DEFAULT_OUTPUT_FORMAT,
seed: Optional[int] = None
seed: int | None = None,
) -> str:
"""
Generate images from text prompts using FAL.ai's FLUX 2 Pro model with automatic upscaling.
Uses the synchronous fal_client API to avoid event loop lifecycle issues.
The async API's global httpx.AsyncClient (cached via @cached_property) breaks
when asyncio.run() destroys and recreates event loops between calls, which
happens in the gateway's thread-pool pattern.
Args:
prompt (str): The text prompt describing the desired image
aspect_ratio (str): Image aspect ratio - "landscape", "square", or "portrait" (default: "landscape")
@ -238,7 +235,7 @@ def image_generate_tool(
num_images (int): Number of images to generate (1-4, default: 1)
output_format (str): Image format "jpeg" or "png" (default: "png")
seed (Optional[int]): Random seed for reproducible results (optional)
Returns:
str: JSON string containing minimal generation results:
{
@ -252,7 +249,7 @@ def image_generate_tool(
logger.warning("Invalid aspect_ratio '%s', defaulting to '%s'", aspect_ratio, DEFAULT_ASPECT_RATIO)
aspect_ratio_lower = DEFAULT_ASPECT_RATIO
image_size = ASPECT_RATIO_MAP[aspect_ratio_lower]
debug_call_data = {
"parameters": {
"prompt": prompt,
@ -262,32 +259,32 @@ def image_generate_tool(
"guidance_scale": guidance_scale,
"num_images": num_images,
"output_format": output_format,
"seed": seed
"seed": seed,
},
"error": None,
"success": False,
"images_generated": 0,
"generation_time": 0
"generation_time": 0,
}
start_time = datetime.datetime.now()
try:
logger.info("Generating %s image(s) with FLUX 2 Pro: %s", num_images, prompt[:80])
# Validate prompt
if not prompt or not isinstance(prompt, str) or len(prompt.strip()) == 0:
raise ValueError("Prompt is required and must be a non-empty string")
# Check API key availability
if not os.getenv("FAL_KEY"):
raise ValueError("FAL_KEY environment variable not set")
# Validate other parameters
validated_params = _validate_parameters(
image_size, num_inference_steps, guidance_scale, num_images, output_format, "none"
)
# Prepare arguments for FAL.ai FLUX 2 Pro API
arguments = {
"prompt": prompt.strip(),
@ -298,51 +295,44 @@ def image_generate_tool(
"output_format": validated_params["output_format"],
"enable_safety_checker": ENABLE_SAFETY_CHECKER,
"safety_tolerance": SAFETY_TOLERANCE,
"sync_mode": True # Use sync mode for immediate results
"sync_mode": True, # Use sync mode for immediate results
}
# Add seed if provided
if seed is not None and isinstance(seed, int):
arguments["seed"] = seed
logger.info("Submitting generation request to FAL.ai FLUX 2 Pro...")
logger.info(" Model: %s", DEFAULT_MODEL)
logger.info(" Aspect Ratio: %s -> %s", aspect_ratio_lower, image_size)
logger.info(" Steps: %s", validated_params['num_inference_steps'])
logger.info(" Guidance: %s", validated_params['guidance_scale'])
logger.info(" Steps: %s", validated_params["num_inference_steps"])
logger.info(" Guidance: %s", validated_params["guidance_scale"])
# Submit request to FAL.ai using sync API (avoids cached event loop issues)
handler = fal_client.submit(
DEFAULT_MODEL,
arguments=arguments
)
handler = fal_client.submit(DEFAULT_MODEL, arguments=arguments)
# Get the result (sync — blocks until done)
result = handler.get()
generation_time = (datetime.datetime.now() - start_time).total_seconds()
# Process the response
if not result or "images" not in result:
raise ValueError("Invalid response from FAL.ai API - no images returned")
images = result.get("images", [])
if not images:
raise ValueError("No images were generated")
# Format image data and upscale images
formatted_images = []
for img in images:
if isinstance(img, dict) and "url" in img:
original_image = {
"url": img["url"],
"width": img.get("width", 0),
"height": img.get("height", 0)
}
original_image = {"url": img["url"], "width": img.get("width", 0), "height": img.get("height", 0)}
# Attempt to upscale the image
upscaled_image = _upscale_image(img["url"], prompt.strip())
if upscaled_image:
# Use upscaled image if successful
formatted_images.append(upscaled_image)
@ -351,52 +341,48 @@ def image_generate_tool(
logger.warning("Using original image as fallback")
original_image["upscaled"] = False
formatted_images.append(original_image)
if not formatted_images:
raise ValueError("No valid image URLs returned from API")
upscaled_count = sum(1 for img in formatted_images if img.get("upscaled", False))
logger.info("Generated %s image(s) in %.1fs (%s upscaled)", len(formatted_images), generation_time, upscaled_count)
logger.info(
"Generated %s image(s) in %.1fs (%s upscaled)", len(formatted_images), generation_time, upscaled_count
)
# Prepare successful response - minimal format
response_data = {
"success": True,
"image": formatted_images[0]["url"] if formatted_images else None
}
response_data = {"success": True, "image": formatted_images[0]["url"] if formatted_images else None}
debug_call_data["success"] = True
debug_call_data["images_generated"] = len(formatted_images)
debug_call_data["generation_time"] = generation_time
# Log debug information
_debug.log_call("image_generate_tool", debug_call_data)
_debug.save()
return json.dumps(response_data, indent=2, ensure_ascii=False)
except Exception as e:
generation_time = (datetime.datetime.now() - start_time).total_seconds()
error_msg = f"Error generating image: {str(e)}"
logger.error("%s", error_msg)
# Prepare error response - minimal format
response_data = {
"success": False,
"image": None
}
response_data = {"success": False, "image": None}
debug_call_data["error"] = error_msg
debug_call_data["generation_time"] = generation_time
_debug.log_call("image_generate_tool", debug_call_data)
_debug.save()
return json.dumps(response_data, indent=2, ensure_ascii=False)
def check_fal_api_key() -> bool:
"""
Check if the FAL.ai API key is available in environment variables.
Returns:
bool: True if API key is set, False otherwise
"""
@ -406,7 +392,7 @@ def check_fal_api_key() -> bool:
def check_image_generation_requirements() -> bool:
"""
Check if all requirements for image generation tools are met.
Returns:
bool: True if requirements are met, False otherwise
"""
@ -414,19 +400,20 @@ def check_image_generation_requirements() -> bool:
# Check API key
if not check_fal_api_key():
return False
# Check if fal_client is available
import fal_client
return True
except ImportError:
return False
def get_debug_session_info() -> Dict[str, Any]:
def get_debug_session_info() -> dict[str, Any]:
"""
Get information about the current debug session.
Returns:
Dict[str, Any]: Dictionary containing debug session information
"""
@ -439,10 +426,10 @@ if __name__ == "__main__":
"""
print("🎨 Image Generation Tools Module - FLUX 2 Pro + Auto Upscaling")
print("=" * 60)
# Check if API key is available
api_available = check_fal_api_key()
if not api_available:
print("❌ FAL_KEY environment variable not set")
print("Please set your API key: export FAL_KEY='your-key-here'")
@ -450,27 +437,28 @@ if __name__ == "__main__":
exit(1)
else:
print("✅ FAL.ai API key found")
# Check if fal_client is available
try:
import fal_client
print("✅ fal_client library available")
except ImportError:
print("❌ fal_client library not found")
print("Please install: pip install fal-client")
exit(1)
print("🛠️ Image generation tools ready for use!")
print(f"🤖 Using model: {DEFAULT_MODEL}")
print(f"🔍 Auto-upscaling with: {UPSCALER_MODEL} ({UPSCALER_FACTOR}x)")
# Show debug mode status
if _debug.active:
print(f"🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
print(f" Debug logs will be saved to: ./logs/image_tools_debug_{_debug.session_id}.json")
else:
print("🐛 Debug mode disabled (set IMAGE_TOOLS_DEBUG=true to enable)")
print("\nBasic usage:")
print(" from image_generation_tool import image_generate_tool")
print(" import asyncio")
@ -484,23 +472,23 @@ if __name__ == "__main__":
print(" )")
print(" print(result)")
print(" asyncio.run(main())")
print("\nSupported image sizes:")
for size in VALID_IMAGE_SIZES:
print(f" - {size}")
print(" - Custom: {'width': 512, 'height': 768} (if needed)")
print("\nAcceleration modes:")
for mode in VALID_ACCELERATION_MODES:
print(f" - {mode}")
print("\nExample prompts:")
print(" - 'A candid street photo of a woman with a pink bob and bold eyeliner'")
print(" - 'Modern architecture building with glass facade, sunset lighting'")
print(" - 'Abstract art with vibrant colors and geometric patterns'")
print(" - 'Portrait of a wise old owl perched on ancient tree branch'")
print(" - 'Futuristic cityscape with flying cars and neon lights'")
print("\nDebug mode:")
print(" # Enable debug logging")
print(" export IMAGE_TOOLS_DEBUG=true")
@ -521,17 +509,17 @@ IMAGE_GENERATE_SCHEMA = {
"properties": {
"prompt": {
"type": "string",
"description": "The text prompt describing the desired image. Be detailed and descriptive."
"description": "The text prompt describing the desired image. Be detailed and descriptive.",
},
"aspect_ratio": {
"type": "string",
"enum": ["landscape", "square", "portrait"],
"description": "The aspect ratio of the generated image. 'landscape' is 16:9 wide, 'portrait' is 16:9 tall, 'square' is 1:1.",
"default": "landscape"
}
"default": "landscape",
},
},
"required": ["prompt"]
}
"required": ["prompt"],
},
}

View file

@ -77,7 +77,7 @@ import os
import re
import threading
import time
from typing import Any, Dict, List, Optional
from typing import Any
logger = logging.getLogger(__name__)
@ -91,9 +91,11 @@ _MCP_SAMPLING_TYPES = False
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
_MCP_AVAILABLE = True
try:
from mcp.client.streamable_http import streamablehttp_client
_MCP_HTTP_AVAILABLE = True
except ImportError:
_MCP_HTTP_AVAILABLE = False
@ -108,6 +110,7 @@ try:
TextContent,
ToolUseContent,
)
_MCP_SAMPLING_TYPES = True
except ImportError:
logger.debug("MCP sampling types not available -- sampling disabled")
@ -118,27 +121,36 @@ except ImportError:
# Constants
# ---------------------------------------------------------------------------
_DEFAULT_TOOL_TIMEOUT = 120 # seconds for tool calls
_DEFAULT_CONNECT_TIMEOUT = 60 # seconds for initial connection per server
_DEFAULT_TOOL_TIMEOUT = 120 # seconds for tool calls
_DEFAULT_CONNECT_TIMEOUT = 60 # seconds for initial connection per server
_MAX_RECONNECT_RETRIES = 5
_MAX_BACKOFF_SECONDS = 60
# Environment variables that are safe to pass to stdio subprocesses
_SAFE_ENV_KEYS = frozenset({
"PATH", "HOME", "USER", "LANG", "LC_ALL", "TERM", "SHELL", "TMPDIR",
})
_SAFE_ENV_KEYS = frozenset(
{
"PATH",
"HOME",
"USER",
"LANG",
"LC_ALL",
"TERM",
"SHELL",
"TMPDIR",
}
)
# Regex for credential patterns to strip from error messages
_CREDENTIAL_PATTERN = re.compile(
r"(?:"
r"ghp_[A-Za-z0-9_]{1,255}" # GitHub PAT
r"|sk-[A-Za-z0-9_]{1,255}" # OpenAI-style key
r"|Bearer\s+\S+" # Bearer token
r"|token=[^\s&,;\"']{1,255}" # token=...
r"|key=[^\s&,;\"']{1,255}" # key=...
r"|API_KEY=[^\s&,;\"']{1,255}" # API_KEY=...
r"|password=[^\s&,;\"']{1,255}" # password=...
r"|secret=[^\s&,;\"']{1,255}" # secret=...
r"ghp_[A-Za-z0-9_]{1,255}" # GitHub PAT
r"|sk-[A-Za-z0-9_]{1,255}" # OpenAI-style key
r"|Bearer\s+\S+" # Bearer token
r"|token=[^\s&,;\"']{1,255}" # token=...
r"|key=[^\s&,;\"']{1,255}" # key=...
r"|API_KEY=[^\s&,;\"']{1,255}" # API_KEY=...
r"|password=[^\s&,;\"']{1,255}" # password=...
r"|secret=[^\s&,;\"']{1,255}" # secret=...
r")",
re.IGNORECASE,
)
@ -148,7 +160,8 @@ _CREDENTIAL_PATTERN = re.compile(
# Security helpers
# ---------------------------------------------------------------------------
def _build_safe_env(user_env: Optional[dict]) -> dict:
def _build_safe_env(user_env: dict | None) -> dict:
"""Build a filtered environment dict for stdio subprocesses.
Only passes through safe baseline variables (PATH, HOME, etc.) and XDG_*
@ -180,6 +193,7 @@ def _sanitize_error(text: str) -> str:
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
# ---------------------------------------------------------------------------
def _safe_numeric(value, default, coerce=int, minimum=1):
"""Coerce a config value to a numeric type, returning *default* on failure.
@ -216,18 +230,22 @@ class SamplingHandler:
self.timeout = _safe_numeric(config.get("timeout", 30), 30, float)
self.max_tokens_cap = _safe_numeric(config.get("max_tokens_cap", 4096), 4096, int)
self.max_tool_rounds = _safe_numeric(
config.get("max_tool_rounds", 5), 5, int, minimum=0,
config.get("max_tool_rounds", 5),
5,
int,
minimum=0,
)
self.model_override = config.get("model")
self.allowed_models = config.get("allowed_models", [])
_log_levels = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING}
self.audit_level = _log_levels.get(
str(config.get("log_level", "info")).lower(), logging.INFO,
str(config.get("log_level", "info")).lower(),
logging.INFO,
)
# Per-instance state
self._rate_timestamps: List[float] = []
self._rate_timestamps: list[float] = []
self._tool_loop_count = 0
self.metrics = {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}
@ -245,7 +263,7 @@ class SamplingHandler:
# -- Model resolution ----------------------------------------------------
def _resolve_model(self, preferences) -> Optional[str]:
def _resolve_model(self, preferences) -> str | None:
"""Config override > server hint > None (use default)."""
if self.model_override:
return self.model_override
@ -265,7 +283,7 @@ class SamplingHandler:
items = block.content if isinstance(block.content, list) else [block.content]
return "\n".join(item.text for item in items if hasattr(item, "text"))
def _convert_messages(self, params) -> List[dict]:
def _convert_messages(self, params) -> list[dict]:
"""Convert MCP SamplingMessages to OpenAI format.
Uses ``msg.content_as_list`` (SDK helper) so single-block and
@ -273,37 +291,47 @@ class SamplingHandler:
with ``isinstance`` on real SDK types when available, falling back
to duck-typing via ``hasattr`` for compatibility.
"""
messages: List[dict] = []
messages: list[dict] = []
for msg in params.messages:
blocks = msg.content_as_list if hasattr(msg, "content_as_list") else (
msg.content if isinstance(msg.content, list) else [msg.content]
blocks = (
msg.content_as_list
if hasattr(msg, "content_as_list")
else (msg.content if isinstance(msg.content, list) else [msg.content])
)
# Separate blocks by kind
tool_results = [b for b in blocks if hasattr(b, "toolUseId")]
tool_uses = [b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")]
content_blocks = [b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))]
tool_uses = [
b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")
]
content_blocks = [
b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))
]
# Emit tool result messages (role: tool)
for tr in tool_results:
messages.append({
"role": "tool",
"tool_call_id": tr.toolUseId,
"content": self._extract_tool_result_text(tr),
})
messages.append(
{
"role": "tool",
"tool_call_id": tr.toolUseId,
"content": self._extract_tool_result_text(tr),
}
)
# Emit assistant tool_calls message
if tool_uses:
tc_list = []
for tu in tool_uses:
tc_list.append({
"id": getattr(tu, "id", f"call_{len(tc_list)}"),
"type": "function",
"function": {
"name": tu.name,
"arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
},
})
tc_list.append(
{
"id": getattr(tu, "id", f"call_{len(tc_list)}"),
"type": "function",
"function": {
"name": tu.name,
"arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
},
}
)
msg_dict: dict = {"role": msg.role, "tool_calls": tc_list}
# Include any accompanying text
text_parts = [b.text for b in content_blocks if hasattr(b, "text")]
@ -320,10 +348,12 @@ class SamplingHandler:
if hasattr(block, "text"):
parts.append({"type": "text", "text": block.text})
elif hasattr(block, "data") and hasattr(block, "mimeType"):
parts.append({
"type": "image_url",
"image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
})
parts.append(
{
"type": "image_url",
"image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
}
)
else:
logger.warning(
"Unsupported sampling content block type: %s (skipped)",
@ -352,16 +382,13 @@ class SamplingHandler:
# Tool loop governance
if self.max_tool_rounds == 0:
self._tool_loop_count = 0
return self._error(
f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)"
)
return self._error(f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)")
self._tool_loop_count += 1
if self._tool_loop_count > self.max_tool_rounds:
self._tool_loop_count = 0
return self._error(
f"Tool loop limit exceeded for server '{self.server_name}' "
f"(max {self.max_tool_rounds} rounds)"
f"Tool loop limit exceeded for server '{self.server_name}' (max {self.max_tool_rounds} rounds)"
)
content_blocks = []
@ -372,25 +399,28 @@ class SamplingHandler:
parsed = json.loads(args)
except (json.JSONDecodeError, ValueError):
logger.warning(
"MCP server '%s': malformed tool_calls arguments "
"from LLM (wrapping as raw): %.100s",
self.server_name, args,
"MCP server '%s': malformed tool_calls arguments from LLM (wrapping as raw): %.100s",
self.server_name,
args,
)
parsed = {"_raw": args}
else:
parsed = args if isinstance(args, dict) else {"_raw": str(args)}
content_blocks.append(ToolUseContent(
type="tool_use",
id=tc.id,
name=tc.function.name,
input=parsed,
))
content_blocks.append(
ToolUseContent(
type="tool_use",
id=tc.id,
name=tc.function.name,
input=parsed,
)
)
logger.log(
self.audit_level,
"MCP server '%s' sampling response: model=%s, tokens=%s, tool_calls=%d",
self.server_name, response.model,
self.server_name,
response.model,
getattr(getattr(response, "usage", None), "total_tokens", "?"),
len(content_blocks),
)
@ -410,7 +440,8 @@ class SamplingHandler:
logger.log(
self.audit_level,
"MCP server '%s' sampling response: model=%s, tokens=%s",
self.server_name, response.model,
self.server_name,
response.model,
getattr(getattr(response, "usage", None), "total_tokens", "?"),
)
@ -445,12 +476,12 @@ class SamplingHandler:
if not self._check_rate_limit():
logger.warning(
"MCP server '%s' sampling rate limit exceeded (%d/min)",
self.server_name, self.max_rpm,
self.server_name,
self.max_rpm,
)
self.metrics["errors"] += 1
return self._error(
f"Sampling rate limit exceeded for server '{self.server_name}' "
f"({self.max_rpm} requests/minute)"
f"Sampling rate limit exceeded for server '{self.server_name}' ({self.max_rpm} requests/minute)"
)
# Resolve model
@ -458,6 +489,7 @@ class SamplingHandler:
# Get auxiliary LLM client
from agent.auxiliary_client import get_text_auxiliary_client
client, default_model = get_text_auxiliary_client()
if client is None:
self.metrics["errors"] += 1
@ -469,7 +501,8 @@ class SamplingHandler:
if self.allowed_models and resolved_model not in self.allowed_models:
logger.warning(
"MCP server '%s' requested model '%s' not in allowed_models",
self.server_name, resolved_model,
self.server_name,
resolved_model,
)
self.metrics["errors"] += 1
return self._error(
@ -515,7 +548,10 @@ class SamplingHandler:
logger.log(
self.audit_level,
"MCP server '%s' sampling request: model=%s, max_tokens=%d, messages=%d",
self.server_name, resolved_model, max_tokens, len(messages),
self.server_name,
resolved_model,
max_tokens,
len(messages),
)
# Offload sync LLM call to thread (non-blocking)
@ -524,19 +560,15 @@ class SamplingHandler:
try:
response = await asyncio.wait_for(
asyncio.to_thread(_sync_call), timeout=self.timeout,
asyncio.to_thread(_sync_call),
timeout=self.timeout,
)
except asyncio.TimeoutError:
except TimeoutError:
self.metrics["errors"] += 1
return self._error(
f"Sampling LLM call timed out after {self.timeout}s "
f"for server '{self.server_name}'"
)
return self._error(f"Sampling LLM call timed out after {self.timeout}s for server '{self.server_name}'")
except Exception as exc:
self.metrics["errors"] += 1
return self._error(
f"Sampling LLM call failed: {_sanitize_error(str(exc))}"
)
return self._error(f"Sampling LLM call failed: {_sanitize_error(str(exc))}")
# Track metrics
choice = response.choices[0]
@ -546,11 +578,7 @@ class SamplingHandler:
self.metrics["tokens_used"] += total_tokens
# Dispatch based on response type
if (
choice.finish_reason == "tool_calls"
and hasattr(choice.message, "tool_calls")
and choice.message.tool_calls
):
if choice.finish_reason == "tool_calls" and hasattr(choice.message, "tool_calls") and choice.message.tool_calls:
return self._build_tool_use_result(choice, response)
return self._build_text_result(choice, response)
@ -560,6 +588,7 @@ class SamplingHandler:
# Server task -- each MCP server lives in one long-lived asyncio Task
# ---------------------------------------------------------------------------
class MCPServerTask:
"""Manages a single MCP server connection in a dedicated asyncio Task.
@ -571,22 +600,29 @@ class MCPServerTask:
"""
__slots__ = (
"name", "session", "tool_timeout",
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
"name",
"session",
"tool_timeout",
"_task",
"_ready",
"_shutdown_event",
"_tools",
"_error",
"_config",
"_sampling",
)
def __init__(self, name: str):
self.name = name
self.session: Optional[Any] = None
self.session: Any | None = None
self.tool_timeout: float = _DEFAULT_TOOL_TIMEOUT
self._task: Optional[asyncio.Task] = None
self._task: asyncio.Task | None = None
self._ready = asyncio.Event()
self._shutdown_event = asyncio.Event()
self._tools: list = []
self._error: Optional[Exception] = None
self._error: Exception | None = None
self._config: dict = {}
self._sampling: Optional[SamplingHandler] = None
self._sampling: SamplingHandler | None = None
def _is_http(self) -> bool:
"""Check if this server uses HTTP transport."""
@ -599,9 +635,7 @@ class MCPServerTask:
user_env = config.get("env")
if not command:
raise ValueError(
f"MCP server '{self.name}' has no 'command' in config"
)
raise ValueError(f"MCP server '{self.name}' has no 'command' in config")
safe_env = _build_safe_env(user_env)
server_params = StdioServerParameters(
@ -650,11 +684,7 @@ class MCPServerTask:
if self.session is None:
return
tools_result = await self.session.list_tools()
self._tools = (
tools_result.tools
if hasattr(tools_result, "tools")
else []
)
self._tools = tools_result.tools if hasattr(tools_result, "tools") else []
async def run(self, config: dict):
"""Long-lived coroutine: connect, discover tools, wait, disconnect.
@ -704,24 +734,28 @@ class MCPServerTask:
if self._shutdown_event.is_set():
logger.debug(
"MCP server '%s' disconnected during shutdown: %s",
self.name, exc,
self.name,
exc,
)
return
retries += 1
if retries > _MAX_RECONNECT_RETRIES:
logger.warning(
"MCP server '%s' failed after %d reconnection attempts, "
"giving up: %s",
self.name, _MAX_RECONNECT_RETRIES, exc,
"MCP server '%s' failed after %d reconnection attempts, giving up: %s",
self.name,
_MAX_RECONNECT_RETRIES,
exc,
)
return
logger.warning(
"MCP server '%s' connection lost (attempt %d/%d), "
"reconnecting in %.0fs: %s",
self.name, retries, _MAX_RECONNECT_RETRIES,
backoff, exc,
"MCP server '%s' connection lost (attempt %d/%d), reconnecting in %.0fs: %s",
self.name,
retries,
_MAX_RECONNECT_RETRIES,
backoff,
exc,
)
await asyncio.sleep(backoff)
backoff = min(backoff * 2, _MAX_BACKOFF_SECONDS)
@ -745,7 +779,7 @@ class MCPServerTask:
if self._task and not self._task.done():
try:
await asyncio.wait_for(self._task, timeout=10)
except asyncio.TimeoutError:
except TimeoutError:
logger.warning(
"MCP server '%s' shutdown timed out, cancelling task",
self.name,
@ -762,11 +796,11 @@ class MCPServerTask:
# Module-level state
# ---------------------------------------------------------------------------
_servers: Dict[str, MCPServerTask] = {}
_servers: dict[str, MCPServerTask] = {}
# Dedicated event loop running in a background daemon thread.
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
_mcp_thread: Optional[threading.Thread] = None
_mcp_loop: asyncio.AbstractEventLoop | None = None
_mcp_thread: threading.Thread | None = None
# Protects _mcp_loop, _mcp_thread, and _servers from concurrent access.
_lock = threading.Lock()
@ -801,7 +835,8 @@ def _run_on_mcp_loop(coro, timeout: float = 30):
# Config loading
# ---------------------------------------------------------------------------
def _load_mcp_config() -> Dict[str, dict]:
def _load_mcp_config() -> dict[str, dict]:
"""Read ``mcp_servers`` from the Hermes config file.
Returns a dict of ``{server_name: server_config}`` or empty dict.
@ -811,6 +846,7 @@ def _load_mcp_config() -> Dict[str, dict]:
"""
try:
from hermes_cli.config import load_config
config = load_config()
servers = config.get("mcp_servers")
if not servers or not isinstance(servers, dict):
@ -825,6 +861,7 @@ def _load_mcp_config() -> Dict[str, dict]:
# Server connection helper
# ---------------------------------------------------------------------------
async def _connect_server(name: str, config: dict) -> MCPServerTask:
"""Create an MCPServerTask, start it, and return when ready.
@ -845,6 +882,7 @@ async def _connect_server(name: str, config: dict) -> MCPServerTask:
# Handler / check-fn factories
# ---------------------------------------------------------------------------
def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
"""Return a sync handler that calls an MCP tool via the background loop.
@ -856,27 +894,21 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
with _lock:
server = _servers.get(server_name)
if not server or not server.session:
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
})
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
async def _call():
result = await server.session.call_tool(tool_name, arguments=args)
# MCP CallToolResult has .content (list of content blocks) and .isError
if result.isError:
error_text = ""
for block in (result.content or []):
for block in result.content or []:
if hasattr(block, "text"):
error_text += block.text
return json.dumps({
"error": _sanitize_error(
error_text or "MCP tool returned an error"
)
})
return json.dumps({"error": _sanitize_error(error_text or "MCP tool returned an error")})
# Collect text from content blocks
parts: List[str] = []
for block in (result.content or []):
parts: list[str] = []
for block in result.content or []:
if hasattr(block, "text"):
parts.append(block.text)
return json.dumps({"result": "\n".join(parts) if parts else ""})
@ -886,13 +918,11 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
except Exception as exc:
logger.error(
"MCP tool %s/%s call failed: %s",
server_name, tool_name, exc,
server_name,
tool_name,
exc,
)
return json.dumps({
"error": _sanitize_error(
f"MCP call failed: {type(exc).__name__}: {exc}"
)
})
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
return _handler
@ -904,14 +934,12 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
with _lock:
server = _servers.get(server_name)
if not server or not server.session:
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
})
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
async def _call():
result = await server.session.list_resources()
resources = []
for r in (result.resources if hasattr(result, "resources") else []):
for r in result.resources if hasattr(result, "resources") else []:
entry = {}
if hasattr(r, "uri"):
entry["uri"] = str(r.uri)
@ -928,13 +956,11 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except Exception as exc:
logger.error(
"MCP %s/list_resources failed: %s", server_name, exc,
"MCP %s/list_resources failed: %s",
server_name,
exc,
)
return json.dumps({
"error": _sanitize_error(
f"MCP call failed: {type(exc).__name__}: {exc}"
)
})
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
return _handler
@ -946,9 +972,7 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
with _lock:
server = _servers.get(server_name)
if not server or not server.session:
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
})
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
uri = args.get("uri")
if not uri:
@ -957,7 +981,7 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
async def _call():
result = await server.session.read_resource(uri)
# read_resource returns ReadResourceResult with .contents list
parts: List[str] = []
parts: list[str] = []
contents = result.contents if hasattr(result, "contents") else []
for block in contents:
if hasattr(block, "text"):
@ -970,13 +994,11 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except Exception as exc:
logger.error(
"MCP %s/read_resource failed: %s", server_name, exc,
"MCP %s/read_resource failed: %s",
server_name,
exc,
)
return json.dumps({
"error": _sanitize_error(
f"MCP call failed: {type(exc).__name__}: {exc}"
)
})
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
return _handler
@ -988,14 +1010,12 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
with _lock:
server = _servers.get(server_name)
if not server or not server.session:
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
})
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
async def _call():
result = await server.session.list_prompts()
prompts = []
for p in (result.prompts if hasattr(result, "prompts") else []):
for p in result.prompts if hasattr(result, "prompts") else []:
entry = {}
if hasattr(p, "name"):
entry["name"] = p.name
@ -1017,13 +1037,11 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except Exception as exc:
logger.error(
"MCP %s/list_prompts failed: %s", server_name, exc,
"MCP %s/list_prompts failed: %s",
server_name,
exc,
)
return json.dumps({
"error": _sanitize_error(
f"MCP call failed: {type(exc).__name__}: {exc}"
)
})
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
return _handler
@ -1035,9 +1053,7 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
with _lock:
server = _servers.get(server_name)
if not server or not server.session:
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
})
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
name = args.get("name")
if not name:
@ -1048,7 +1064,7 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
result = await server.session.get_prompt(name, arguments=arguments)
# GetPromptResult has .messages list
messages = []
for msg in (result.messages if hasattr(result, "messages") else []):
for msg in result.messages if hasattr(result, "messages") else []:
entry = {}
if hasattr(msg, "role"):
entry["role"] = msg.role
@ -1070,13 +1086,11 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except Exception as exc:
logger.error(
"MCP %s/get_prompt failed: %s", server_name, exc,
"MCP %s/get_prompt failed: %s",
server_name,
exc,
)
return json.dumps({
"error": _sanitize_error(
f"MCP call failed: {type(exc).__name__}: {exc}"
)
})
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
return _handler
@ -1096,6 +1110,7 @@ def _make_check_fn(server_name: str):
# Discovery & registration
# ---------------------------------------------------------------------------
def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
"""Convert an MCP tool listing to the Hermes registry schema format.
@ -1114,14 +1129,16 @@ def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
return {
"name": prefixed_name,
"description": mcp_tool.description or f"MCP tool {mcp_tool.name} from {server_name}",
"parameters": mcp_tool.inputSchema if mcp_tool.inputSchema else {
"parameters": mcp_tool.inputSchema
if mcp_tool.inputSchema
else {
"type": "object",
"properties": {},
},
}
def _build_utility_schemas(server_name: str) -> List[dict]:
def _build_utility_schemas(server_name: str) -> list[dict]:
"""Build schemas for the MCP utility tools (resources & prompts).
Returns a list of (schema, handler_factory_name) tuples encoded as dicts
@ -1192,9 +1209,9 @@ def _build_utility_schemas(server_name: str) -> List[dict]:
]
def _existing_tool_names() -> List[str]:
def _existing_tool_names() -> list[str]:
"""Return tool names for all currently connected servers."""
names: List[str] = []
names: list[str] = []
for sname, server in _servers.items():
for mcp_tool in server._tools:
schema = _convert_mcp_schema(sname, mcp_tool)
@ -1205,7 +1222,7 @@ def _existing_tool_names() -> List[str]:
return names
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
async def _discover_and_register_server(name: str, config: dict) -> list[str]:
"""Connect to a single MCP server, discover tools, and register them.
Also registers utility tools for MCP Resources and Prompts support
@ -1224,7 +1241,7 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
with _lock:
_servers[name] = server
registered_names: List[str] = []
registered_names: list[str] = []
toolset_name = f"mcp-{name}"
for mcp_tool in server._tools:
@ -1277,7 +1294,9 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
transport_type = "HTTP" if "url" in config else "stdio"
logger.info(
"MCP server '%s' (%s): registered %d tool(s): %s",
name, transport_type, len(registered_names),
name,
transport_type,
len(registered_names),
", ".join(registered_names),
)
return registered_names
@ -1287,7 +1306,8 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
# Public API
# ---------------------------------------------------------------------------
def discover_mcp_tools() -> List[str]:
def discover_mcp_tools() -> list[str]:
"""Entry point: load config, connect to MCP servers, register tools.
Called from ``model_tools._discover_tools()``. Safe to call even when
@ -1318,12 +1338,12 @@ def discover_mcp_tools() -> List[str]:
# Start the background event loop for MCP connections
_ensure_mcp_loop()
all_tools: List[str] = []
all_tools: list[str] = []
failed_count = 0
async def _discover_one(name: str, cfg: dict) -> List[str]:
async def _discover_one(name: str, cfg: dict) -> list[str]:
"""Connect to a single server and return its registered tool names."""
transport_desc = cfg.get("url", f'{cfg.get("command", "?")} {" ".join(cfg.get("args", [])[:2])}')
transport_desc = cfg.get("url", f"{cfg.get('command', '?')} {' '.join(cfg.get('args', [])[:2])}")
try:
registered = await _discover_and_register_server(name, cfg)
transport_type = "HTTP" if "url" in cfg else "stdio"
@ -1331,7 +1351,8 @@ def discover_mcp_tools() -> List[str]:
except Exception as exc:
logger.warning(
"Failed to connect to MCP server '%s': %s",
name, exc,
name,
exc,
)
return []
@ -1358,6 +1379,7 @@ def discover_mcp_tools() -> List[str]:
if all_tools:
# Dynamically inject into all hermes-* platform toolsets
from toolsets import TOOLSETS
for ts_name, ts in TOOLSETS.items():
if ts_name.startswith("hermes-"):
for tool_name in all_tools:
@ -1377,13 +1399,13 @@ def discover_mcp_tools() -> List[str]:
return _existing_tool_names()
def get_mcp_status() -> List[dict]:
def get_mcp_status() -> list[dict]:
"""Return status of all configured MCP servers for banner display.
Returns a list of dicts with keys: name, transport, tools, connected.
Includes both successfully connected servers and configured-but-failed ones.
"""
result: List[dict] = []
result: list[dict] = []
# Get configured servers from config
configured = _load_mcp_config()
@ -1407,12 +1429,14 @@ def get_mcp_status() -> List[dict]:
entry["sampling"] = dict(server._sampling.metrics)
result.append(entry)
else:
result.append({
"name": name,
"transport": transport,
"tools": 0,
"connected": False,
})
result.append(
{
"name": name,
"transport": transport,
"tools": 0,
"connected": False,
}
)
return result
@ -1440,7 +1464,9 @@ def shutdown_mcp_servers():
for server, result in zip(servers_snapshot, results):
if isinstance(result, Exception):
logger.debug(
"Error closing MCP server '%s': %s", server.name, result,
"Error closing MCP server '%s': %s",
server.name,
result,
)
with _lock:
_servers.clear()

View file

@ -29,7 +29,7 @@ import os
import re
import tempfile
from pathlib import Path
from typing import Dict, Any, List, Optional
from typing import Any
logger = logging.getLogger(__name__)
@ -46,30 +46,38 @@ ENTRY_DELIMITER = "\n§\n"
_MEMORY_THREAT_PATTERNS = [
# Prompt injection
(r'ignore\s+(previous|all|above|prior)\s+instructions', "prompt_injection"),
(r'you\s+are\s+now\s+', "role_hijack"),
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
(r'system\s+prompt\s+override', "sys_prompt_override"),
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
(r'act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)', "bypass_restrictions"),
(r"ignore\s+(previous|all|above|prior)\s+instructions", "prompt_injection"),
(r"you\s+are\s+now\s+", "role_hijack"),
(r"do\s+not\s+tell\s+the\s+user", "deception_hide"),
(r"system\s+prompt\s+override", "sys_prompt_override"),
(r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"),
(r"act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)", "bypass_restrictions"),
# Exfiltration via curl/wget with secrets
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
(r'wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_wget"),
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)', "read_secrets"),
(r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"),
(r"wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_wget"),
(r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)", "read_secrets"),
# Persistence via shell rc
(r'authorized_keys', "ssh_backdoor"),
(r'\$HOME/\.ssh|\~/\.ssh', "ssh_access"),
(r'\$HOME/\.hermes/\.env|\~/\.hermes/\.env', "hermes_env"),
(r"authorized_keys", "ssh_backdoor"),
(r"\$HOME/\.ssh|\~/\.ssh", "ssh_access"),
(r"\$HOME/\.hermes/\.env|\~/\.hermes/\.env", "hermes_env"),
]
# Subset of invisible chars for injection detection
_INVISIBLE_CHARS = {
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
"\u200b",
"\u200c",
"\u200d",
"\u2060",
"\ufeff",
"\u202a",
"\u202b",
"\u202c",
"\u202d",
"\u202e",
}
def _scan_memory_content(content: str) -> Optional[str]:
def _scan_memory_content(content: str) -> str | None:
"""Scan memory content for injection/exfil patterns. Returns error string if blocked."""
# Check invisible unicode
for char in _INVISIBLE_CHARS:
@ -96,12 +104,12 @@ class MemoryStore:
"""
def __init__(self, memory_char_limit: int = 2200, user_char_limit: int = 1375):
self.memory_entries: List[str] = []
self.user_entries: List[str] = []
self.memory_entries: list[str] = []
self.user_entries: list[str] = []
self.memory_char_limit = memory_char_limit
self.user_char_limit = user_char_limit
# Frozen snapshot for system prompt -- set once at load_from_disk()
self._system_prompt_snapshot: Dict[str, str] = {"memory": "", "user": ""}
self._system_prompt_snapshot: dict[str, str] = {"memory": "", "user": ""}
def load_from_disk(self):
"""Load entries from MEMORY.md and USER.md, capture system prompt snapshot."""
@ -129,12 +137,12 @@ class MemoryStore:
elif target == "user":
self._write_file(MEMORY_DIR / "USER.md", self.user_entries)
def _entries_for(self, target: str) -> List[str]:
def _entries_for(self, target: str) -> list[str]:
if target == "user":
return self.user_entries
return self.memory_entries
def _set_entries(self, target: str, entries: List[str]):
def _set_entries(self, target: str, entries: list[str]):
if target == "user":
self.user_entries = entries
else:
@ -151,7 +159,7 @@ class MemoryStore:
return self.user_char_limit
return self.memory_char_limit
def add(self, target: str, content: str) -> Dict[str, Any]:
def add(self, target: str, content: str) -> dict[str, Any]:
"""Append a new entry. Returns error if it would exceed the char limit."""
content = content.strip()
if not content:
@ -192,7 +200,7 @@ class MemoryStore:
return self._success_response(target, "Entry added.")
def replace(self, target: str, old_text: str, new_content: str) -> Dict[str, Any]:
def replace(self, target: str, old_text: str, new_content: str) -> dict[str, Any]:
"""Find entry containing old_text substring, replace it with new_content."""
old_text = old_text.strip()
new_content = new_content.strip()
@ -247,7 +255,7 @@ class MemoryStore:
return self._success_response(target, "Entry replaced.")
def remove(self, target: str, old_text: str) -> Dict[str, Any]:
def remove(self, target: str, old_text: str) -> dict[str, Any]:
"""Remove the entry containing old_text substring."""
old_text = old_text.strip()
if not old_text:
@ -278,7 +286,7 @@ class MemoryStore:
return self._success_response(target, "Entry removed.")
def format_for_system_prompt(self, target: str) -> Optional[str]:
def format_for_system_prompt(self, target: str) -> str | None:
"""
Return the frozen snapshot for system prompt injection.
@ -293,7 +301,7 @@ class MemoryStore:
# -- Internal helpers --
def _success_response(self, target: str, message: str = None) -> Dict[str, Any]:
def _success_response(self, target: str, message: str = None) -> dict[str, Any]:
entries = self._entries_for(target)
current = self._char_count(target)
limit = self._char_limit(target)
@ -310,7 +318,7 @@ class MemoryStore:
resp["message"] = message
return resp
def _render_block(self, target: str, entries: List[str]) -> str:
def _render_block(self, target: str, entries: list[str]) -> str:
"""Render a system prompt block with header and usage indicator."""
if not entries:
return ""
@ -329,7 +337,7 @@ class MemoryStore:
return f"{separator}\n{header}\n{separator}\n{content}"
@staticmethod
def _read_file(path: Path) -> List[str]:
def _read_file(path: Path) -> list[str]:
"""Read a memory file and split into entries.
No file locking needed: _write_file uses atomic rename, so readers
@ -339,7 +347,7 @@ class MemoryStore:
return []
try:
raw = path.read_text(encoding="utf-8")
except (OSError, IOError):
except OSError:
return []
if not raw.strip():
@ -351,7 +359,7 @@ class MemoryStore:
return [e for e in entries if e]
@staticmethod
def _write_file(path: Path, entries: List[str]):
def _write_file(path: Path, entries: list[str]):
"""Write entries to a memory file using atomic temp-file + rename.
Previous implementation used open("w") + flock, but "w" truncates the
@ -362,9 +370,7 @@ class MemoryStore:
content = ENTRY_DELIMITER.join(entries) if entries else ""
try:
# Write to temp file in same directory (same filesystem for atomic rename)
fd, tmp_path = tempfile.mkstemp(
dir=str(path.parent), suffix=".tmp", prefix=".mem_"
)
fd, tmp_path = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp", prefix=".mem_")
try:
with os.fdopen(fd, "w", encoding="utf-8") as f:
f.write(content)
@ -378,7 +384,7 @@ class MemoryStore:
except OSError:
pass
raise
except (OSError, IOError) as e:
except OSError as e:
raise RuntimeError(f"Failed to write memory file {path}: {e}")
@ -387,7 +393,7 @@ def memory_tool(
target: str = "memory",
content: str = None,
old_text: str = None,
store: Optional[MemoryStore] = None,
store: MemoryStore | None = None,
) -> str:
"""
Single entry point for the memory tool. Dispatches to MemoryStore methods.
@ -395,10 +401,15 @@ def memory_tool(
Returns JSON string with results.
"""
if store is None:
return json.dumps({"success": False, "error": "Memory is not available. It may be disabled in config or this environment."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "Memory is not available. It may be disabled in config or this environment."},
ensure_ascii=False,
)
if target not in ("memory", "user"):
return json.dumps({"success": False, "error": f"Invalid target '{target}'. Use 'memory' or 'user'."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": f"Invalid target '{target}'. Use 'memory' or 'user'."}, ensure_ascii=False
)
if action == "add":
if not content:
@ -407,18 +418,26 @@ def memory_tool(
elif action == "replace":
if not old_text:
return json.dumps({"success": False, "error": "old_text is required for 'replace' action."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "old_text is required for 'replace' action."}, ensure_ascii=False
)
if not content:
return json.dumps({"success": False, "error": "content is required for 'replace' action."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "content is required for 'replace' action."}, ensure_ascii=False
)
result = store.replace(target, old_text, content)
elif action == "remove":
if not old_text:
return json.dumps({"success": False, "error": "old_text is required for 'remove' action."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "old_text is required for 'remove' action."}, ensure_ascii=False
)
result = store.remove(target, old_text)
else:
return json.dumps({"success": False, "error": f"Unknown action '{action}'. Use: add, replace, remove"}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": f"Unknown action '{action}'. Use: add, replace, remove"}, ensure_ascii=False
)
return json.dumps(result, ensure_ascii=False)
@ -457,23 +476,16 @@ MEMORY_SCHEMA = {
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["add", "replace", "remove"],
"description": "The action to perform."
},
"action": {"type": "string", "enum": ["add", "replace", "remove"], "description": "The action to perform."},
"target": {
"type": "string",
"enum": ["memory", "user"],
"description": "Which memory store: 'memory' for personal notes, 'user' for user profile."
},
"content": {
"type": "string",
"description": "The entry content. Required for 'add' and 'replace'."
"description": "Which memory store: 'memory' for personal notes, 'user' for user profile.",
},
"content": {"type": "string", "description": "The entry content. Required for 'add' and 'replace'."},
"old_text": {
"type": "string",
"description": "Short unique substring identifying the entry to replace or remove."
"description": "Short unique substring identifying the entry to replace or remove.",
},
},
"required": ["action", "target"],
@ -493,10 +505,7 @@ registry.register(
target=args.get("target", "memory"),
content=args.get("content"),
old_text=args.get("old_text"),
store=kw.get("store")),
store=kw.get("store"),
),
check_fn=check_memory_requirements,
)

View file

@ -38,21 +38,27 @@ Configuration:
Usage:
from mixture_of_agents_tool import mixture_of_agents_tool
import asyncio
# Process a complex query
result = await mixture_of_agents_tool(
user_prompt="Solve this complex mathematical proof..."
)
"""
import asyncio
import datetime
import json
import logging
import os
import asyncio
import datetime
from typing import Dict, Any, List, Optional
from tools.openrouter_client import get_async_client as _get_openrouter_client, check_api_key as check_openrouter_api_key
from typing import Any
from tools.debug_helpers import DebugSession
from tools.openrouter_client import (
check_api_key as check_openrouter_api_key,
)
from tools.openrouter_client import (
get_async_client as _get_openrouter_client,
)
logger = logging.getLogger(__name__)
@ -60,9 +66,9 @@ logger = logging.getLogger(__name__)
# Reference models - these generate diverse initial responses in parallel (OpenRouter slugs)
REFERENCE_MODELS = [
"anthropic/claude-opus-4.5",
"google/gemini-3-pro-preview",
"google/gemini-3-pro-preview",
"openai/gpt-5.2-pro",
"deepseek/deepseek-v3.2"
"deepseek/deepseek-v3.2",
]
# Aggregator model - synthesizes reference responses into final output
@ -83,18 +89,18 @@ Responses from models:"""
_debug = DebugSession("moa_tools", env_var="MOA_TOOLS_DEBUG")
def _construct_aggregator_prompt(system_prompt: str, responses: List[str]) -> str:
def _construct_aggregator_prompt(system_prompt: str, responses: list[str]) -> str:
"""
Construct the final system prompt for the aggregator including all model responses.
Args:
system_prompt (str): Base system prompt for aggregation
responses (List[str]): List of responses from reference models
Returns:
str: Complete system prompt with enumerated responses
"""
response_text = "\n".join([f"{i+1}. {response}" for i, response in enumerate(responses)])
response_text = "\n".join([f"{i + 1}. {response}" for i, response in enumerate(responses)])
return f"{system_prompt}\n\n{response_text}"
@ -103,48 +109,43 @@ async def _run_reference_model_safe(
user_prompt: str,
temperature: float = REFERENCE_TEMPERATURE,
max_tokens: int = 32000,
max_retries: int = 6
max_retries: int = 6,
) -> tuple[str, str, bool]:
"""
Run a single reference model with retry logic and graceful failure handling.
Args:
model (str): Model identifier to use
user_prompt (str): The user's query
temperature (float): Sampling temperature for response generation
max_tokens (int): Maximum tokens in response
max_retries (int): Maximum number of retry attempts
Returns:
tuple[str, str, bool]: (model_name, response_content_or_error, success_flag)
"""
for attempt in range(max_retries):
try:
logger.info("Querying %s (attempt %s/%s)", model, attempt + 1, max_retries)
# Build parameters for the API call
api_params = {
"model": model,
"messages": [{"role": "user", "content": user_prompt}],
"extra_body": {
"reasoning": {
"enabled": True,
"effort": "xhigh"
}
}
"extra_body": {"reasoning": {"enabled": True, "effort": "xhigh"}},
}
# GPT models (especially gpt-4o-mini) don't support custom temperature values
# Only include temperature for non-GPT models
if not model.lower().startswith('gpt-'):
if not model.lower().startswith("gpt-"):
api_params["temperature"] = temperature
response = await _get_openrouter_client().chat.completions.create(**api_params)
content = response.choices[0].message.content.strip()
logger.info("%s responded (%s characters)", model, len(content))
return model, content, True
except Exception as e:
error_str = str(e)
# Log more detailed error information for debugging
@ -154,7 +155,7 @@ async def _run_reference_model_safe(
logger.warning("%s rate limit error (attempt %s): %s", model, attempt + 1, error_str)
else:
logger.warning("%s unknown error (attempt %s): %s", model, attempt + 1, error_str)
if attempt < max_retries - 1:
# Exponential backoff for rate limiting: 2s, 4s, 8s, 16s, 32s, 60s
sleep_time = min(2 ** (attempt + 1), 60)
@ -167,60 +168,47 @@ async def _run_reference_model_safe(
async def _run_aggregator_model(
system_prompt: str,
user_prompt: str,
temperature: float = AGGREGATOR_TEMPERATURE,
max_tokens: int = None
system_prompt: str, user_prompt: str, temperature: float = AGGREGATOR_TEMPERATURE, max_tokens: int = None
) -> str:
"""
Run the aggregator model to synthesize the final response.
Args:
system_prompt (str): System prompt with all reference responses
user_prompt (str): Original user query
temperature (float): Focused temperature for consistent aggregation
max_tokens (int): Maximum tokens in final response
Returns:
str: Synthesized final response
"""
logger.info("Running aggregator model: %s", AGGREGATOR_MODEL)
# Build parameters for the API call
api_params = {
"model": AGGREGATOR_MODEL,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
"extra_body": {
"reasoning": {
"enabled": True,
"effort": "xhigh"
}
}
"messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
"extra_body": {"reasoning": {"enabled": True, "effort": "xhigh"}},
}
# GPT models (especially gpt-4o-mini) don't support custom temperature values
# Only include temperature for non-GPT models
if not AGGREGATOR_MODEL.lower().startswith('gpt-'):
if not AGGREGATOR_MODEL.lower().startswith("gpt-"):
api_params["temperature"] = temperature
response = await _get_openrouter_client().chat.completions.create(**api_params)
content = response.choices[0].message.content.strip()
logger.info("Aggregation complete (%s characters)", len(content))
return content
async def mixture_of_agents_tool(
user_prompt: str,
reference_models: Optional[List[str]] = None,
aggregator_model: Optional[str] = None
user_prompt: str, reference_models: list[str] | None = None, aggregator_model: str | None = None
) -> str:
"""
Process a complex query using the Mixture-of-Agents methodology.
This tool leverages multiple frontier language models to collaboratively solve
extremely difficult problems requiring intense reasoning. It's particularly
effective for:
@ -229,16 +217,16 @@ async def mixture_of_agents_tool(
- Multi-step analytical reasoning tasks
- Problems requiring diverse domain expertise
- Tasks where single models show limitations
The MoA approach uses a fixed 2-layer architecture:
1. Layer 1: Multiple reference models generate diverse responses in parallel (temp=0.6)
2. Layer 2: Aggregator model synthesizes the best elements into final response (temp=0.4)
Args:
user_prompt (str): The complex query or problem to solve
reference_models (Optional[List[str]]): Custom reference models to use
aggregator_model (Optional[str]): Custom aggregator model to use
Returns:
str: JSON string containing the MoA results with the following structure:
{
@ -250,12 +238,12 @@ async def mixture_of_agents_tool(
},
"processing_time": float
}
Raises:
Exception: If MoA processing fails or API key is not set
"""
start_time = datetime.datetime.now()
debug_call_data = {
"parameters": {
"user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt,
@ -263,7 +251,7 @@ async def mixture_of_agents_tool(
"aggregator_model": aggregator_model or AGGREGATOR_MODEL,
"reference_temperature": REFERENCE_TEMPERATURE,
"aggregator_temperature": AGGREGATOR_TEMPERATURE,
"min_successful_references": MIN_SUCCESSFUL_REFERENCES
"min_successful_references": MIN_SUCCESSFUL_REFERENCES,
},
"error": None,
"success": False,
@ -272,161 +260,152 @@ async def mixture_of_agents_tool(
"failed_models": [],
"final_response_length": 0,
"processing_time_seconds": 0,
"models_used": {}
"models_used": {},
}
try:
logger.info("Starting Mixture-of-Agents processing...")
logger.info("Query: %s", user_prompt[:100])
# Validate API key availability
if not os.getenv("OPENROUTER_API_KEY"):
raise ValueError("OPENROUTER_API_KEY environment variable not set")
# Use provided models or defaults
ref_models = reference_models or REFERENCE_MODELS
agg_model = aggregator_model or AGGREGATOR_MODEL
logger.info("Using %s reference models in 2-layer MoA architecture", len(ref_models))
# Layer 1: Generate diverse responses from reference models (with failure handling)
logger.info("Layer 1: Generating reference responses...")
model_results = await asyncio.gather(*[
_run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE)
for model in ref_models
])
model_results = await asyncio.gather(
*[_run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE) for model in ref_models]
)
# Separate successful and failed responses
successful_responses = []
failed_models = []
for model_name, content, success in model_results:
if success:
successful_responses.append(content)
else:
failed_models.append(model_name)
successful_count = len(successful_responses)
failed_count = len(failed_models)
logger.info("Reference model results: %s successful, %s failed", successful_count, failed_count)
if failed_models:
logger.warning("Failed models: %s", ', '.join(failed_models))
logger.warning("Failed models: %s", ", ".join(failed_models))
# Check if we have enough successful responses to proceed
if successful_count < MIN_SUCCESSFUL_REFERENCES:
raise ValueError(f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses.")
raise ValueError(
f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses."
)
debug_call_data["reference_responses_count"] = successful_count
debug_call_data["failed_models_count"] = failed_count
debug_call_data["failed_models"] = failed_models
# Layer 2: Aggregate responses using the aggregator model
logger.info("Layer 2: Synthesizing final response...")
aggregator_system_prompt = _construct_aggregator_prompt(
AGGREGATOR_SYSTEM_PROMPT,
successful_responses
)
final_response = await _run_aggregator_model(
aggregator_system_prompt,
user_prompt,
AGGREGATOR_TEMPERATURE
)
aggregator_system_prompt = _construct_aggregator_prompt(AGGREGATOR_SYSTEM_PROMPT, successful_responses)
final_response = await _run_aggregator_model(aggregator_system_prompt, user_prompt, AGGREGATOR_TEMPERATURE)
# Calculate processing time
end_time = datetime.datetime.now()
processing_time = (end_time - start_time).total_seconds()
logger.info("MoA processing completed in %.2f seconds", processing_time)
# Prepare successful response (only final aggregated result, minimal fields)
result = {
"success": True,
"response": final_response,
"models_used": {
"reference_models": ref_models,
"aggregator_model": agg_model
}
"models_used": {"reference_models": ref_models, "aggregator_model": agg_model},
}
debug_call_data["success"] = True
debug_call_data["final_response_length"] = len(final_response)
debug_call_data["processing_time_seconds"] = processing_time
debug_call_data["models_used"] = result["models_used"]
# Log debug information
_debug.log_call("mixture_of_agents_tool", debug_call_data)
_debug.save()
return json.dumps(result, indent=2, ensure_ascii=False)
except Exception as e:
error_msg = f"Error in MoA processing: {str(e)}"
logger.error("%s", error_msg)
# Calculate processing time even for errors
end_time = datetime.datetime.now()
processing_time = (end_time - start_time).total_seconds()
# Prepare error response (minimal fields)
result = {
"success": False,
"response": "MoA processing failed. Please try again or use a single model for this query.",
"models_used": {
"reference_models": reference_models or REFERENCE_MODELS,
"aggregator_model": aggregator_model or AGGREGATOR_MODEL
"aggregator_model": aggregator_model or AGGREGATOR_MODEL,
},
"error": error_msg
"error": error_msg,
}
debug_call_data["error"] = error_msg
debug_call_data["processing_time_seconds"] = processing_time
_debug.log_call("mixture_of_agents_tool", debug_call_data)
_debug.save()
return json.dumps(result, indent=2, ensure_ascii=False)
def check_moa_requirements() -> bool:
"""
Check if all requirements for MoA tools are met.
Returns:
bool: True if requirements are met, False otherwise
"""
return check_openrouter_api_key()
def get_debug_session_info() -> Dict[str, Any]:
def get_debug_session_info() -> dict[str, Any]:
"""
Get information about the current debug session.
Returns:
Dict[str, Any]: Dictionary containing debug session information
"""
return _debug.get_session_info()
def get_available_models() -> Dict[str, List[str]]:
def get_available_models() -> dict[str, list[str]]:
"""
Get information about available models for MoA processing.
Returns:
Dict[str, List[str]]: Dictionary with reference and aggregator models
"""
return {
"reference_models": REFERENCE_MODELS,
"aggregator_models": [AGGREGATOR_MODEL],
"supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL]
"supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL],
}
def get_moa_configuration() -> Dict[str, Any]:
def get_moa_configuration() -> dict[str, Any]:
"""
Get the current MoA configuration settings.
Returns:
Dict[str, Any]: Dictionary containing all configuration parameters
"""
@ -437,7 +416,7 @@ def get_moa_configuration() -> Dict[str, Any]:
"aggregator_temperature": AGGREGATOR_TEMPERATURE,
"min_successful_references": MIN_SUCCESSFUL_REFERENCES,
"total_reference_models": len(REFERENCE_MODELS),
"failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail"
"failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail",
}
@ -447,10 +426,10 @@ if __name__ == "__main__":
"""
print("🤖 Mixture-of-Agents Tool Module")
print("=" * 50)
# Check if API key is available
api_available = check_openrouter_api_key()
if not api_available:
print("❌ OPENROUTER_API_KEY environment variable not set")
print("Please set your API key: export OPENROUTER_API_KEY='your-key-here'")
@ -458,26 +437,26 @@ if __name__ == "__main__":
exit(1)
else:
print("✅ OpenRouter API key found")
print("🛠️ MoA tools ready for use!")
# Show current configuration
config = get_moa_configuration()
print(f"\n⚙️ Current Configuration:")
print("\n⚙️ Current Configuration:")
print(f" 🤖 Reference models ({len(config['reference_models'])}): {', '.join(config['reference_models'])}")
print(f" 🧠 Aggregator model: {config['aggregator_model']}")
print(f" 🌡️ Reference temperature: {config['reference_temperature']}")
print(f" 🌡️ Aggregator temperature: {config['aggregator_temperature']}")
print(f" 🛡️ Failure tolerance: {config['failure_tolerance']}")
print(f" 📊 Minimum successful models: {config['min_successful_references']}")
# Show debug mode status
if _debug.active:
print(f"\n🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
print(f" Debug logs will be saved to: ./logs/moa_tools_debug_{_debug.session_id}.json")
else:
print("\n🐛 Debug mode disabled (set MOA_TOOLS_DEBUG=true to enable)")
print("\nBasic usage:")
print(" from mixture_of_agents_tool import mixture_of_agents_tool")
print(" import asyncio")
@ -488,24 +467,26 @@ if __name__ == "__main__":
print(" )")
print(" print(result)")
print(" asyncio.run(main())")
print("\nBest use cases:")
print(" - Complex mathematical proofs and calculations")
print(" - Advanced coding problems and algorithm design")
print(" - Multi-step analytical reasoning tasks")
print(" - Problems requiring diverse domain expertise")
print(" - Tasks where single models show limitations")
print("\nPerformance characteristics:")
print(" - Higher latency due to multiple model calls")
print(" - Significantly improved quality for complex tasks")
print(" - Parallel processing for efficiency")
print(f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation")
print(
f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation"
)
print(" - Token-efficient: only returns final aggregated response")
print(" - Resilient: continues with partial model failures")
print(f" - Configurable: easy to modify models and settings at top of file")
print(" - Configurable: easy to modify models and settings at top of file")
print(" - State-of-the-art results on challenging benchmarks")
print("\nDebug mode:")
print(" # Enable debug logging")
print(" export MOA_TOOLS_DEBUG=true")
@ -526,11 +507,11 @@ MOA_SCHEMA = {
"properties": {
"user_prompt": {
"type": "string",
"description": "The complex query or problem to solve using multiple AI models. Should be a challenging problem that benefits from diverse perspectives and collaborative reasoning."
"description": "The complex query or problem to solve using multiple AI models. Should be a challenging problem that benefits from diverse perspectives and collaborative reasoning.",
}
},
"required": ["user_prompt"]
}
"required": ["user_prompt"],
},
}
registry.register(

View file

@ -1,7 +1,7 @@
"""Shared OpenRouter API client for Hermes tools.
Provides a single lazy-initialized AsyncOpenAI client that all tool modules
can share, eliminating the duplicated _get_openrouter_client() /
can share, eliminating the duplicated _get_openrouter_client() /
_get_summarizer_client() pattern previously copy-pasted across web_tools,
vision_tools, mixture_of_agents_tool, and session_search_tool.
"""
@ -9,6 +9,7 @@ vision_tools, mixture_of_agents_tool, and session_search_tool.
import os
from openai import AsyncOpenAI
from hermes_constants import OPENROUTER_BASE_URL
_client: AsyncOpenAI | None = None

View file

@ -20,7 +20,7 @@ V4A Format:
Usage:
from tools.patch_parser import parse_v4a_patch, apply_v4a_operations
operations, error = parse_v4a_patch(patch_content)
if error:
print(f"Parse error: {error}")
@ -30,8 +30,8 @@ Usage:
import re
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Any
from enum import Enum
from typing import Any
class OperationType(Enum):
@ -44,6 +44,7 @@ class OperationType(Enum):
@dataclass
class HunkLine:
"""A single line in a patch hunk."""
prefix: str # ' ', '-', or '+'
content: str
@ -51,182 +52,174 @@ class HunkLine:
@dataclass
class Hunk:
"""A group of changes within a file."""
context_hint: Optional[str] = None
lines: List[HunkLine] = field(default_factory=list)
context_hint: str | None = None
lines: list[HunkLine] = field(default_factory=list)
@dataclass
class PatchOperation:
"""A single operation in a V4A patch."""
operation: OperationType
file_path: str
new_path: Optional[str] = None # For move operations
hunks: List[Hunk] = field(default_factory=list)
content: Optional[str] = None # For add file operations
new_path: str | None = None # For move operations
hunks: list[Hunk] = field(default_factory=list)
content: str | None = None # For add file operations
def parse_v4a_patch(patch_content: str) -> Tuple[List[PatchOperation], Optional[str]]:
def parse_v4a_patch(patch_content: str) -> tuple[list[PatchOperation], str | None]:
"""
Parse a V4A format patch.
Args:
patch_content: The patch text in V4A format
Returns:
Tuple of (operations, error_message)
- If successful: (list_of_operations, None)
- If failed: ([], error_description)
"""
lines = patch_content.split('\n')
operations: List[PatchOperation] = []
lines = patch_content.split("\n")
operations: list[PatchOperation] = []
# Find patch boundaries
start_idx = None
end_idx = None
for i, line in enumerate(lines):
if '*** Begin Patch' in line or '***Begin Patch' in line:
if "*** Begin Patch" in line or "***Begin Patch" in line:
start_idx = i
elif '*** End Patch' in line or '***End Patch' in line:
elif "*** End Patch" in line or "***End Patch" in line:
end_idx = i
break
if start_idx is None:
# Try to parse without explicit begin marker
start_idx = -1
if end_idx is None:
end_idx = len(lines)
# Parse operations between boundaries
i = start_idx + 1
current_op: Optional[PatchOperation] = None
current_hunk: Optional[Hunk] = None
current_op: PatchOperation | None = None
current_hunk: Hunk | None = None
while i < end_idx:
line = lines[i]
# Check for file operation markers
update_match = re.match(r'\*\*\*\s*Update\s+File:\s*(.+)', line)
add_match = re.match(r'\*\*\*\s*Add\s+File:\s*(.+)', line)
delete_match = re.match(r'\*\*\*\s*Delete\s+File:\s*(.+)', line)
move_match = re.match(r'\*\*\*\s*Move\s+File:\s*(.+?)\s*->\s*(.+)', line)
update_match = re.match(r"\*\*\*\s*Update\s+File:\s*(.+)", line)
add_match = re.match(r"\*\*\*\s*Add\s+File:\s*(.+)", line)
delete_match = re.match(r"\*\*\*\s*Delete\s+File:\s*(.+)", line)
move_match = re.match(r"\*\*\*\s*Move\s+File:\s*(.+?)\s*->\s*(.+)", line)
if update_match:
# Save previous operation
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
operations.append(current_op)
current_op = PatchOperation(
operation=OperationType.UPDATE,
file_path=update_match.group(1).strip()
)
current_op = PatchOperation(operation=OperationType.UPDATE, file_path=update_match.group(1).strip())
current_hunk = None
elif add_match:
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
operations.append(current_op)
current_op = PatchOperation(
operation=OperationType.ADD,
file_path=add_match.group(1).strip()
)
current_op = PatchOperation(operation=OperationType.ADD, file_path=add_match.group(1).strip())
current_hunk = Hunk()
elif delete_match:
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
operations.append(current_op)
current_op = PatchOperation(
operation=OperationType.DELETE,
file_path=delete_match.group(1).strip()
)
current_op = PatchOperation(operation=OperationType.DELETE, file_path=delete_match.group(1).strip())
operations.append(current_op)
current_op = None
current_hunk = None
elif move_match:
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
operations.append(current_op)
current_op = PatchOperation(
operation=OperationType.MOVE,
file_path=move_match.group(1).strip(),
new_path=move_match.group(2).strip()
new_path=move_match.group(2).strip(),
)
operations.append(current_op)
current_op = None
current_hunk = None
elif line.startswith('@@'):
elif line.startswith("@@"):
# Context hint / hunk marker
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
# Extract context hint
hint_match = re.match(r'@@\s*(.+?)\s*@@', line)
hint_match = re.match(r"@@\s*(.+?)\s*@@", line)
hint = hint_match.group(1) if hint_match else None
current_hunk = Hunk(context_hint=hint)
elif current_op and line:
# Parse hunk line
if current_hunk is None:
current_hunk = Hunk()
if line.startswith('+'):
current_hunk.lines.append(HunkLine('+', line[1:]))
elif line.startswith('-'):
current_hunk.lines.append(HunkLine('-', line[1:]))
elif line.startswith(' '):
current_hunk.lines.append(HunkLine(' ', line[1:]))
elif line.startswith('\\'):
if line.startswith("+"):
current_hunk.lines.append(HunkLine("+", line[1:]))
elif line.startswith("-"):
current_hunk.lines.append(HunkLine("-", line[1:]))
elif line.startswith(" "):
current_hunk.lines.append(HunkLine(" ", line[1:]))
elif line.startswith("\\"):
# "\ No newline at end of file" marker - skip
pass
else:
# Treat as context line (implicit space prefix)
current_hunk.lines.append(HunkLine(' ', line))
current_hunk.lines.append(HunkLine(" ", line))
i += 1
# Don't forget the last operation
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
operations.append(current_op)
return operations, None
def apply_v4a_operations(operations: List[PatchOperation],
file_ops: Any) -> 'PatchResult':
def apply_v4a_operations(operations: list[PatchOperation], file_ops: Any) -> "PatchResult":
"""
Apply V4A patch operations using a file operations interface.
Args:
operations: List of PatchOperation from parse_v4a_patch
file_ops: Object with read_file, write_file methods
Returns:
PatchResult with results of all operations
"""
# Import here to avoid circular imports
from tools.file_operations import PatchResult
files_modified = []
files_created = []
files_deleted = []
all_diffs = []
errors = []
for op in operations:
try:
if op.operation == OperationType.ADD:
@ -236,7 +229,7 @@ def apply_v4a_operations(operations: List[PatchOperation],
all_diffs.append(result[1])
else:
errors.append(f"Failed to add {op.file_path}: {result[1]}")
elif op.operation == OperationType.DELETE:
result = _apply_delete(op, file_ops)
if result[0]:
@ -244,7 +237,7 @@ def apply_v4a_operations(operations: List[PatchOperation],
all_diffs.append(result[1])
else:
errors.append(f"Failed to delete {op.file_path}: {result[1]}")
elif op.operation == OperationType.MOVE:
result = _apply_move(op, file_ops)
if result[0]:
@ -252,7 +245,7 @@ def apply_v4a_operations(operations: List[PatchOperation],
all_diffs.append(result[1])
else:
errors.append(f"Failed to move {op.file_path}: {result[1]}")
elif op.operation == OperationType.UPDATE:
result = _apply_update(op, file_ops)
if result[0]:
@ -260,19 +253,19 @@ def apply_v4a_operations(operations: List[PatchOperation],
all_diffs.append(result[1])
else:
errors.append(f"Failed to update {op.file_path}: {result[1]}")
except Exception as e:
errors.append(f"Error processing {op.file_path}: {str(e)}")
# Run lint on all modified/created files
lint_results = {}
for f in files_modified + files_created:
if hasattr(file_ops, '_check_lint'):
if hasattr(file_ops, "_check_lint"):
lint_result = file_ops._check_lint(f)
lint_results[f] = lint_result.to_dict()
combined_diff = '\n'.join(all_diffs)
combined_diff = "\n".join(all_diffs)
if errors:
return PatchResult(
success=False,
@ -281,123 +274,124 @@ def apply_v4a_operations(operations: List[PatchOperation],
files_created=files_created,
files_deleted=files_deleted,
lint=lint_results if lint_results else None,
error='; '.join(errors)
error="; ".join(errors),
)
return PatchResult(
success=True,
diff=combined_diff,
files_modified=files_modified,
files_created=files_created,
files_deleted=files_deleted,
lint=lint_results if lint_results else None
lint=lint_results if lint_results else None,
)
def _apply_add(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
def _apply_add(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
"""Apply an add file operation."""
# Extract content from hunks (all + lines)
content_lines = []
for hunk in op.hunks:
for line in hunk.lines:
if line.prefix == '+':
if line.prefix == "+":
content_lines.append(line.content)
content = '\n'.join(content_lines)
content = "\n".join(content_lines)
result = file_ops.write_file(op.file_path, content)
if result.error:
return False, result.error
diff = f"--- /dev/null\n+++ b/{op.file_path}\n"
diff += '\n'.join(f"+{line}" for line in content_lines)
diff += "\n".join(f"+{line}" for line in content_lines)
return True, diff
def _apply_delete(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
def _apply_delete(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
"""Apply a delete file operation."""
# Read file first for diff
read_result = file_ops.read_file(op.file_path)
if read_result.error and "not found" in read_result.error.lower():
# File doesn't exist, nothing to delete
return True, f"# {op.file_path} already deleted or doesn't exist"
# Delete directly via shell command using the underlying environment
rm_result = file_ops._exec(f"rm -f {file_ops._escape_shell_arg(op.file_path)}")
if rm_result.exit_code != 0:
return False, rm_result.stdout
diff = f"--- a/{op.file_path}\n+++ /dev/null\n# File deleted"
return True, diff
def _apply_move(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
def _apply_move(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
"""Apply a move file operation."""
# Use shell mv command
mv_result = file_ops._exec(
f"mv {file_ops._escape_shell_arg(op.file_path)} {file_ops._escape_shell_arg(op.new_path)}"
)
if mv_result.exit_code != 0:
return False, mv_result.stdout
diff = f"# Moved: {op.file_path} -> {op.new_path}"
return True, diff
def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
def _apply_update(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
"""Apply an update file operation."""
# Read current content
read_result = file_ops.read_file(op.file_path, limit=10000)
if read_result.error:
return False, f"Cannot read file: {read_result.error}"
# Parse content (remove line numbers)
current_lines = []
for line in read_result.content.split('\n'):
if '|' in line:
for line in read_result.content.split("\n"):
if "|" in line:
# Line format: " 123|content"
parts = line.split('|', 1)
parts = line.split("|", 1)
if len(parts) == 2:
current_lines.append(parts[1])
else:
current_lines.append(line)
else:
current_lines.append(line)
current_content = '\n'.join(current_lines)
current_content = "\n".join(current_lines)
# Apply each hunk
new_content = current_content
for hunk in op.hunks:
# Build search pattern from context and removed lines
search_lines = []
replace_lines = []
for line in hunk.lines:
if line.prefix == ' ':
if line.prefix == " ":
search_lines.append(line.content)
replace_lines.append(line.content)
elif line.prefix == '-':
elif line.prefix == "-":
search_lines.append(line.content)
elif line.prefix == '+':
elif line.prefix == "+":
replace_lines.append(line.content)
if search_lines:
search_pattern = '\n'.join(search_lines)
replacement = '\n'.join(replace_lines)
search_pattern = "\n".join(search_lines)
replacement = "\n".join(replace_lines)
# Use fuzzy matching
from tools.fuzzy_match import fuzzy_find_and_replace
new_content, count, error = fuzzy_find_and_replace(
new_content, search_pattern, replacement, replace_all=False
)
if error and count == 0:
# Try with context hint if available
if hunk.context_hint:
@ -408,31 +402,32 @@ def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
window_start = max(0, hint_pos - 500)
window_end = min(len(new_content), hint_pos + 2000)
window = new_content[window_start:window_end]
window_new, count, error = fuzzy_find_and_replace(
window, search_pattern, replacement, replace_all=False
)
if count > 0:
new_content = new_content[:window_start] + window_new + new_content[window_end:]
error = None
if error:
return False, f"Could not apply hunk: {error}"
# Write new content
write_result = file_ops.write_file(op.file_path, new_content)
if write_result.error:
return False, write_result.error
# Generate diff
import difflib
diff_lines = difflib.unified_diff(
current_content.splitlines(keepends=True),
new_content.splitlines(keepends=True),
fromfile=f"a/{op.file_path}",
tofile=f"b/{op.file_path}"
tofile=f"b/{op.file_path}",
)
diff = ''.join(diff_lines)
diff = "".join(diff_lines)
return True, diff

View file

@ -34,7 +34,6 @@ import logging
import os
import platform
import shlex
import shutil
import signal
import subprocess
import threading
@ -42,10 +41,11 @@ import time
import uuid
_IS_WINDOWS = platform.system() == "Windows"
from tools.environments.local import _find_shell
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any
from tools.environments.local import _find_shell
logger = logging.getLogger(__name__)
@ -54,30 +54,31 @@ logger = logging.getLogger(__name__)
CHECKPOINT_PATH = Path(os.path.expanduser("~/.hermes/processes.json"))
# Limits
MAX_OUTPUT_CHARS = 200_000 # 200KB rolling output buffer
FINISHED_TTL_SECONDS = 1800 # Keep finished processes for 30 minutes
MAX_PROCESSES = 64 # Max concurrent tracked processes (LRU pruning)
MAX_OUTPUT_CHARS = 200_000 # 200KB rolling output buffer
FINISHED_TTL_SECONDS = 1800 # Keep finished processes for 30 minutes
MAX_PROCESSES = 64 # Max concurrent tracked processes (LRU pruning)
@dataclass
class ProcessSession:
"""A tracked background process with output buffering."""
id: str # Unique session ID ("proc_xxxxxxxxxxxx")
command: str # Original command string
task_id: str = "" # Task/sandbox isolation key
session_key: str = "" # Gateway session key (for reset protection)
pid: Optional[int] = None # OS process ID
process: Optional[subprocess.Popen] = None # Popen handle (local only)
env_ref: Any = None # Reference to the environment object
cwd: Optional[str] = None # Working directory
started_at: float = 0.0 # time.time() of spawn
exited: bool = False # Whether the process has finished
exit_code: Optional[int] = None # Exit code (None if still running)
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
id: str # Unique session ID ("proc_xxxxxxxxxxxx")
command: str # Original command string
task_id: str = "" # Task/sandbox isolation key
session_key: str = "" # Gateway session key (for reset protection)
pid: int | None = None # OS process ID
process: subprocess.Popen | None = None # Popen handle (local only)
env_ref: Any = None # Reference to the environment object
cwd: str | None = None # Working directory
started_at: float = 0.0 # time.time() of spawn
exited: bool = False # Whether the process has finished
exit_code: int | None = None # Exit code (None if still running)
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
max_output_chars: int = MAX_OUTPUT_CHARS
detached: bool = False # True if recovered from crash (no pipe)
detached: bool = False # True if recovered from crash (no pipe)
_lock: threading.Lock = field(default_factory=threading.Lock)
_reader_thread: Optional[threading.Thread] = field(default=None, repr=False)
_reader_thread: threading.Thread | None = field(default=None, repr=False)
_pty: Any = field(default=None, repr=False) # ptyprocess handle (when use_pty=True)
@ -100,12 +101,12 @@ class ProcessRegistry:
)
def __init__(self):
self._running: Dict[str, ProcessSession] = {}
self._finished: Dict[str, ProcessSession] = {}
self._running: dict[str, ProcessSession] = {}
self._finished: dict[str, ProcessSession] = {}
self._lock = threading.Lock()
# Side-channel for check_interval watchers (gateway reads after agent run)
self.pending_watchers: List[Dict[str, Any]] = []
self.pending_watchers: list[dict[str, Any]] = []
@staticmethod
def _clean_shell_noise(text: str) -> str:
@ -149,6 +150,7 @@ class ProcessRegistry:
# Try PTY mode for interactive CLI tools
try:
import ptyprocess
user_shell = _find_shell()
pty_env = os.environ | (env_vars or {})
pty_env["PYTHONUNBUFFERED"] = "1"
@ -260,10 +262,7 @@ class ProcessRegistry:
log_path = f"/tmp/hermes_bg_{session.id}.log"
pid_path = f"/tmp/hermes_bg_{session.id}.pid"
quoted_command = shlex.quote(command)
bg_command = (
f"nohup bash -c {quoted_command} > {log_path} 2>&1 & "
f"echo $! > {pid_path} && cat {pid_path}"
)
bg_command = f"nohup bash -c {quoted_command} > {log_path} 2>&1 & echo $! > {pid_path} && cat {pid_path}"
try:
result = env.execute(bg_command, timeout=timeout)
@ -313,7 +312,7 @@ class ProcessRegistry:
with session._lock:
session.output_buffer += chunk
if len(session.output_buffer) > session.max_output_chars:
session.output_buffer = session.output_buffer[-session.max_output_chars:]
session.output_buffer = session.output_buffer[-session.max_output_chars :]
except Exception as e:
logger.debug("Process stdout reader ended: %s", e)
@ -326,9 +325,7 @@ class ProcessRegistry:
session.exit_code = session.process.returncode
self._move_to_finished(session)
def _env_poller_loop(
self, session: ProcessSession, env: Any, log_path: str, pid_path: str
):
def _env_poller_loop(self, session: ProcessSession, env: Any, log_path: str, pid_path: str):
"""Background thread: poll a sandbox log file for non-local backends."""
while not session.exited:
time.sleep(2) # Poll every 2 seconds
@ -340,7 +337,7 @@ class ProcessRegistry:
with session._lock:
session.output_buffer = new_output
if len(session.output_buffer) > session.max_output_chars:
session.output_buffer = session.output_buffer[-session.max_output_chars:]
session.output_buffer = session.output_buffer[-session.max_output_chars :]
# Check if process is still running
check = env.execute(
@ -383,7 +380,7 @@ class ProcessRegistry:
with session._lock:
session.output_buffer += text
if len(session.output_buffer) > session.max_output_chars:
session.output_buffer = session.output_buffer[-session.max_output_chars:]
session.output_buffer = session.output_buffer[-session.max_output_chars :]
except EOFError:
break
except Exception:
@ -397,7 +394,7 @@ class ProcessRegistry:
except Exception as e:
logger.debug("PTY wait timed out or failed: %s", e)
session.exited = True
session.exit_code = pty.exitstatus if hasattr(pty, 'exitstatus') else -1
session.exit_code = pty.exitstatus if hasattr(pty, "exitstatus") else -1
self._move_to_finished(session)
def _move_to_finished(self, session: ProcessSession):
@ -409,7 +406,7 @@ class ProcessRegistry:
# ----- Query Methods -----
def get(self, session_id: str) -> Optional[ProcessSession]:
def get(self, session_id: str) -> ProcessSession | None:
"""Get a session by ID (running or finished)."""
with self._lock:
return self._running.get(session_id) or self._finished.get(session_id)
@ -454,7 +451,7 @@ class ProcessRegistry:
if offset == 0 and limit > 0:
selected = lines[-limit:]
else:
selected = lines[offset:offset + limit]
selected = lines[offset : offset + limit]
return {
"session_id": session.id,
@ -485,10 +482,7 @@ class ProcessRegistry:
if requested_timeout and requested_timeout > max_timeout:
effective_timeout = max_timeout
timeout_note = (
f"Requested wait of {requested_timeout}s was clamped "
f"to configured limit of {max_timeout}s"
)
timeout_note = f"Requested wait of {requested_timeout}s was clamped to configured limit of {max_timeout}s"
else:
effective_timeout = requested_timeout or max_timeout
@ -581,7 +575,7 @@ class ProcessRegistry:
return {"status": "already_exited", "error": "Process has already finished"}
# PTY mode -- write through pty handle (expects bytes)
if hasattr(session, '_pty') and session._pty:
if hasattr(session, "_pty") and session._pty:
try:
pty_data = data.encode("utf-8") if isinstance(data, str) else data
session._pty.write(pty_data)
@ -635,26 +629,17 @@ class ProcessRegistry:
def has_active_processes(self, task_id: str) -> bool:
"""Check if there are active (running) processes for a task_id."""
with self._lock:
return any(
s.task_id == task_id and not s.exited
for s in self._running.values()
)
return any(s.task_id == task_id and not s.exited for s in self._running.values())
def has_active_for_session(self, session_key: str) -> bool:
"""Check if there are active processes for a gateway session key."""
with self._lock:
return any(
s.session_key == session_key and not s.exited
for s in self._running.values()
)
return any(s.session_key == session_key and not s.exited for s in self._running.values())
def kill_all(self, task_id: str = None) -> int:
"""Kill all running processes, optionally filtered by task_id. Returns count killed."""
with self._lock:
targets = [
s for s in self._running.values()
if (task_id is None or s.task_id == task_id) and not s.exited
]
targets = [s for s in self._running.values() if (task_id is None or s.task_id == task_id) and not s.exited]
killed = 0
for session in targets:
@ -669,10 +654,7 @@ class ProcessRegistry:
"""Remove oldest finished sessions if over MAX_PROCESSES. Must hold _lock."""
# First prune expired finished sessions
now = time.time()
expired = [
sid for sid, s in self._finished.items()
if (now - s.started_at) > FINISHED_TTL_SECONDS
]
expired = [sid for sid, s in self._finished.items() if (now - s.started_at) > FINISHED_TTL_SECONDS]
for sid in expired:
del self._finished[sid]
@ -696,18 +678,21 @@ class ProcessRegistry:
entries = []
for s in self._running.values():
if not s.exited:
entries.append({
"session_id": s.id,
"command": s.command,
"pid": s.pid,
"cwd": s.cwd,
"started_at": s.started_at,
"task_id": s.task_id,
"session_key": s.session_key,
})
entries.append(
{
"session_id": s.id,
"command": s.command,
"pid": s.pid,
"cwd": s.cwd,
"started_at": s.started_at,
"task_id": s.task_id,
"session_key": s.session_key,
}
)
# Atomic write to avoid corruption on crash
from utils import atomic_json_write
atomic_json_write(CHECKPOINT_PATH, entries)
except Exception as e:
logger.debug("Failed to write checkpoint file: %s", e, exc_info=True)
@ -759,6 +744,7 @@ class ProcessRegistry:
# Clear the checkpoint (will be rewritten as processes finish)
try:
from utils import atomic_json_write
atomic_json_write(CHECKPOINT_PATH, [])
except Exception as e:
logger.debug("Could not clear checkpoint file: %s", e, exc_info=True)
@ -790,38 +776,32 @@ PROCESS_SCHEMA = {
"action": {
"type": "string",
"enum": ["list", "poll", "log", "wait", "kill", "write", "submit"],
"description": "Action to perform on background processes"
"description": "Action to perform on background processes",
},
"session_id": {
"type": "string",
"description": "Process session ID (from terminal background output). Required for all actions except 'list'."
"description": "Process session ID (from terminal background output). Required for all actions except 'list'.",
},
"data": {
"type": "string",
"description": "Text to send to process stdin (for 'write' and 'submit' actions)"
"description": "Text to send to process stdin (for 'write' and 'submit' actions)",
},
"timeout": {
"type": "integer",
"description": "Max seconds to block for 'wait' action. Returns partial output on timeout.",
"minimum": 1
"minimum": 1,
},
"offset": {
"type": "integer",
"description": "Line offset for 'log' action (default: last 200 lines)"
},
"limit": {
"type": "integer",
"description": "Max lines to return for 'log' action",
"minimum": 1
}
"offset": {"type": "integer", "description": "Line offset for 'log' action (default: last 200 lines)"},
"limit": {"type": "integer", "description": "Max lines to return for 'log' action", "minimum": 1},
},
"required": ["action"]
}
"required": ["action"],
},
}
def _handle_process(args, **kw):
import json as _json
task_id = kw.get("task_id")
action = args.get("action", "")
# Coerce to string — some models send session_id as an integer
@ -835,8 +815,10 @@ def _handle_process(args, **kw):
if action == "poll":
return _json.dumps(process_registry.poll(session_id), ensure_ascii=False)
elif action == "log":
return _json.dumps(process_registry.read_log(
session_id, offset=args.get("offset", 0), limit=args.get("limit", 200)), ensure_ascii=False)
return _json.dumps(
process_registry.read_log(session_id, offset=args.get("offset", 0), limit=args.get("limit", 200)),
ensure_ascii=False,
)
elif action == "wait":
return _json.dumps(process_registry.wait(session_id, timeout=args.get("timeout")), ensure_ascii=False)
elif action == "kill":
@ -845,7 +827,10 @@ def _handle_process(args, **kw):
return _json.dumps(process_registry.write_stdin(session_id, str(args.get("data", ""))), ensure_ascii=False)
elif action == "submit":
return _json.dumps(process_registry.submit_stdin(session_id, str(args.get("data", ""))), ensure_ascii=False)
return _json.dumps({"error": f"Unknown process action: {action}. Use: list, poll, log, wait, kill, write, submit"}, ensure_ascii=False)
return _json.dumps(
{"error": f"Unknown process action: {action}. Use: list, poll, log, wait, kill, write, submit"},
ensure_ascii=False,
)
registry.register(

View file

@ -16,7 +16,7 @@ Import chain (circular-import safe):
import json
import logging
from typing import Any, Callable, Dict, List, Optional, Set
from collections.abc import Callable
logger = logging.getLogger(__name__)
@ -25,12 +25,17 @@ class ToolEntry:
"""Metadata for a single registered tool."""
__slots__ = (
"name", "toolset", "schema", "handler", "check_fn",
"requires_env", "is_async", "description",
"name",
"toolset",
"schema",
"handler",
"check_fn",
"requires_env",
"is_async",
"description",
)
def __init__(self, name, toolset, schema, handler, check_fn,
requires_env, is_async, description):
def __init__(self, name, toolset, schema, handler, check_fn, requires_env, is_async, description):
self.name = name
self.toolset = toolset
self.schema = schema
@ -45,8 +50,8 @@ class ToolRegistry:
"""Singleton registry that collects tool schemas + handlers from tool files."""
def __init__(self):
self._tools: Dict[str, ToolEntry] = {}
self._toolset_checks: Dict[str, Callable] = {}
self._tools: dict[str, ToolEntry] = {}
self._toolset_checks: dict[str, Callable] = {}
# ------------------------------------------------------------------
# Registration
@ -81,7 +86,7 @@ class ToolRegistry:
# Schema retrieval
# ------------------------------------------------------------------
def get_definitions(self, tool_names: Set[str], quiet: bool = False) -> List[dict]:
def get_definitions(self, tool_names: set[str], quiet: bool = False) -> list[dict]:
"""Return OpenAI-format tool schemas for the requested tool names.
Only tools whose ``check_fn()`` returns True (or have no check_fn)
@ -122,6 +127,7 @@ class ToolRegistry:
try:
if entry.is_async:
from model_tools import _run_async
return _run_async(entry.handler(args, **kwargs))
return entry.handler(args, **kwargs)
except Exception as e:
@ -132,16 +138,16 @@ class ToolRegistry:
# Query helpers (replace redundant dicts in model_tools.py)
# ------------------------------------------------------------------
def get_all_tool_names(self) -> List[str]:
def get_all_tool_names(self) -> list[str]:
"""Return sorted list of all registered tool names."""
return sorted(self._tools.keys())
def get_toolset_for_tool(self, name: str) -> Optional[str]:
def get_toolset_for_tool(self, name: str) -> str | None:
"""Return the toolset a tool belongs to, or None."""
entry = self._tools.get(name)
return entry.toolset if entry else None
def get_tool_to_toolset_map(self) -> Dict[str, str]:
def get_tool_to_toolset_map(self) -> dict[str, str]:
"""Return ``{tool_name: toolset_name}`` for every registered tool."""
return {name: e.toolset for name, e in self._tools.items()}
@ -160,14 +166,14 @@ class ToolRegistry:
logger.debug("Toolset %s check raised; marking unavailable", toolset)
return False
def check_toolset_requirements(self) -> Dict[str, bool]:
def check_toolset_requirements(self) -> dict[str, bool]:
"""Return ``{toolset: available_bool}`` for every toolset."""
toolsets = set(e.toolset for e in self._tools.values())
return {ts: self.is_toolset_available(ts) for ts in sorted(toolsets)}
def get_available_toolsets(self) -> Dict[str, dict]:
def get_available_toolsets(self) -> dict[str, dict]:
"""Return toolset metadata for UI display."""
toolsets: Dict[str, dict] = {}
toolsets: dict[str, dict] = {}
for entry in self._tools.values():
ts = entry.toolset
if ts not in toolsets:
@ -184,9 +190,9 @@ class ToolRegistry:
toolsets[ts]["requirements"].append(env)
return toolsets
def get_toolset_requirements(self) -> Dict[str, dict]:
def get_toolset_requirements(self) -> dict[str, dict]:
"""Build a TOOLSET_REQUIREMENTS-compatible dict for backward compat."""
result: Dict[str, dict] = {}
result: dict[str, dict] = {}
for entry in self._tools.values():
ts = entry.toolset
if ts not in result:
@ -217,11 +223,13 @@ class ToolRegistry:
if self.is_toolset_available(ts):
available.append(ts)
else:
unavailable.append({
"name": ts,
"env_vars": entry.requires_env,
"tools": [e.name for e in self._tools.values() if e.toolset == ts],
})
unavailable.append(
{
"name": ts,
"env_vars": entry.requires_env,
"tools": [e.name for e in self._tools.values() if e.toolset == ts],
}
)
return available, unavailable

File diff suppressed because it is too large Load diff

View file

@ -29,19 +29,16 @@ SEND_MESSAGE_SCHEMA = {
"action": {
"type": "string",
"enum": ["send", "list"],
"description": "Action to perform. 'send' (default) sends a message. 'list' returns all available channels/contacts across connected platforms."
"description": "Action to perform. 'send' (default) sends a message. 'list' returns all available channels/contacts across connected platforms.",
},
"target": {
"type": "string",
"description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', or 'platform:chat_id'. Examples: 'telegram', 'discord:#bot-home', 'slack:#engineering', 'signal:+15551234567'"
"description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', or 'platform:chat_id'. Examples: 'telegram', 'discord:#bot-home', 'slack:#engineering', 'signal:+15551234567'",
},
"message": {
"type": "string",
"description": "The message text to send"
}
"message": {"type": "string", "description": "The message text to send"},
},
"required": []
}
"required": [],
},
}
@ -59,6 +56,7 @@ def _handle_list():
"""Return formatted list of available messaging targets."""
try:
from gateway.channel_directory import format_directory_for_display
return json.dumps({"targets": format_directory_for_display()})
except Exception as e:
return json.dumps({"error": f"Failed to load channel directory: {e}"})
@ -79,26 +77,30 @@ def _handle_send(args):
if chat_id and not chat_id.lstrip("-").isdigit():
try:
from gateway.channel_directory import resolve_channel_name
resolved = resolve_channel_name(platform_name, chat_id)
if resolved:
chat_id = resolved
else:
return json.dumps({
"error": f"Could not resolve '{chat_id}' on {platform_name}. "
f"Use send_message(action='list') to see available targets."
})
return json.dumps(
{
"error": f"Could not resolve '{chat_id}' on {platform_name}. "
f"Use send_message(action='list') to see available targets."
}
)
except Exception:
return json.dumps({
"error": f"Could not resolve '{chat_id}' on {platform_name}. "
f"Try using a numeric channel ID instead."
})
return json.dumps(
{"error": f"Could not resolve '{chat_id}' on {platform_name}. Try using a numeric channel ID instead."}
)
from tools.interrupt import is_interrupted
if is_interrupted():
return json.dumps({"error": "Interrupted"})
try:
from gateway.config import load_gateway_config, Platform
from gateway.config import Platform, load_gateway_config
config = load_gateway_config()
except Exception as e:
return json.dumps({"error": f"Failed to load gateway config: {e}"})
@ -117,7 +119,11 @@ def _handle_send(args):
pconfig = config.platforms.get(platform)
if not pconfig or not pconfig.enabled:
return json.dumps({"error": f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/gateway.json or environment variables."})
return json.dumps(
{
"error": f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/gateway.json or environment variables."
}
)
used_home_channel = False
if not chat_id:
@ -126,14 +132,17 @@ def _handle_send(args):
chat_id = home.chat_id
used_home_channel = True
else:
return json.dumps({
"error": f"No home channel set for {platform_name} to determine where to send the message. "
f"Either specify a channel directly with '{platform_name}:CHANNEL_NAME', "
f"or set a home channel via: hermes config set {platform_name.upper()}_HOME_CHANNEL <channel_id>"
})
return json.dumps(
{
"error": f"No home channel set for {platform_name} to determine where to send the message. "
f"Either specify a channel directly with '{platform_name}:CHANNEL_NAME', "
f"or set a home channel via: hermes config set {platform_name.upper()}_HOME_CHANNEL <channel_id>"
}
)
try:
from model_tools import _run_async
result = _run_async(_send_to_platform(platform, pconfig, chat_id, message))
if used_home_channel and isinstance(result, dict) and result.get("success"):
result["note"] = f"Sent to {platform_name} home channel (chat_id: {chat_id})"
@ -142,6 +151,7 @@ def _handle_send(args):
if isinstance(result, dict) and result.get("success"):
try:
from gateway.mirror import mirror_to_session
source_label = os.getenv("HERMES_SESSION_PLATFORM", "cli")
if mirror_to_session(platform_name, chat_id, message, source_label=source_label):
result["mirrored"] = True
@ -156,6 +166,7 @@ def _handle_send(args):
async def _send_to_platform(platform, pconfig, chat_id, message):
"""Route a message to the appropriate platform sender."""
from gateway.config import Platform
if platform == Platform.TELEGRAM:
return await _send_telegram(pconfig.token, chat_id, message)
elif platform == Platform.DISCORD:
@ -171,6 +182,7 @@ async def _send_telegram(token, chat_id, message):
"""Send via Telegram Bot API (one-shot, no polling needed)."""
try:
from telegram import Bot
bot = Bot(token=token)
msg = await bot.send_message(chat_id=int(chat_id), text=message)
return {"success": True, "platform": "telegram", "chat_id": chat_id, "message_id": str(msg.message_id)}
@ -189,7 +201,7 @@ async def _send_discord(token, chat_id, message):
try:
url = f"https://discord.com/api/v10/channels/{chat_id}/messages"
headers = {"Authorization": f"Bot {token}", "Content-Type": "application/json"}
chunks = [message[i:i+2000] for i in range(0, len(message), 2000)]
chunks = [message[i : i + 2000] for i in range(0, len(message), 2000)]
message_ids = []
async with aiohttp.ClientSession() as session:
for chunk in chunks:
@ -266,6 +278,7 @@ def _check_send_message():
return True
try:
from gateway.status import is_gateway_running
return is_gateway_running()
except Exception:
return False

View file

@ -18,11 +18,8 @@ Flow:
import asyncio
import concurrent.futures
import json
import os
import logging
from typing import Dict, Any, List, Optional, Union
from openai import AsyncOpenAI, OpenAI
from typing import Any
from agent.auxiliary_client import get_async_text_auxiliary_client
@ -33,7 +30,7 @@ MAX_SESSION_CHARS = 100_000
MAX_SUMMARY_TOKENS = 10000
def _format_timestamp(ts: Union[int, float, str, None]) -> str:
def _format_timestamp(ts: int | float | str | None) -> str:
"""Convert a Unix timestamp (float/int) or ISO string to a human-readable date.
Returns "unknown" for None, str(ts) if conversion fails.
@ -43,11 +40,13 @@ def _format_timestamp(ts: Union[int, float, str, None]) -> str:
try:
if isinstance(ts, (int, float)):
from datetime import datetime
dt = datetime.fromtimestamp(ts)
return dt.strftime("%B %d, %Y at %I:%M %p")
if isinstance(ts, str):
if ts.replace(".", "").replace("-", "").isdigit():
from datetime import datetime
dt = datetime.fromtimestamp(float(ts))
return dt.strftime("%B %d, %Y at %I:%M %p")
return ts
@ -59,7 +58,7 @@ def _format_timestamp(ts: Union[int, float, str, None]) -> str:
return str(ts)
def _format_conversation(messages: List[Dict[str, Any]]) -> str:
def _format_conversation(messages: list[dict[str, Any]]) -> str:
"""Format session messages into a readable transcript for summarization."""
parts = []
for msg in messages:
@ -93,9 +92,7 @@ def _format_conversation(messages: List[Dict[str, Any]]) -> str:
return "\n\n".join(parts)
def _truncate_around_matches(
full_text: str, query: str, max_chars: int = MAX_SESSION_CHARS
) -> str:
def _truncate_around_matches(full_text: str, query: str, max_chars: int = MAX_SESSION_CHARS) -> str:
"""
Truncate a conversation transcript to max_chars, centered around
where the query terms appear. Keeps content near matches, trims the edges.
@ -129,9 +126,7 @@ def _truncate_around_matches(
return prefix + truncated + suffix
async def _summarize_session(
conversation_text: str, query: str, session_meta: Dict[str, Any]
) -> Optional[str]:
async def _summarize_session(conversation_text: str, query: str, session_meta: dict[str, Any]) -> str | None:
"""Summarize a single session conversation focused on the search query."""
system_prompt = (
"You are reviewing a past conversation transcript to help recall what happened. "
@ -163,7 +158,8 @@ async def _summarize_session(
max_retries = 3
for attempt in range(max_retries):
try:
from agent.auxiliary_client import get_auxiliary_extra_body, auxiliary_max_tokens_param
from agent.auxiliary_client import auxiliary_max_tokens_param, get_auxiliary_extra_body
_extra = get_auxiliary_extra_body()
response = await _async_aux_client.chat.completions.create(
model=_SUMMARIZER_MODEL,
@ -221,13 +217,16 @@ def session_search(
)
if not raw_results:
return json.dumps({
"success": True,
"query": query,
"results": [],
"count": 0,
"message": "No matching sessions found.",
}, ensure_ascii=False)
return json.dumps(
{
"success": True,
"query": query,
"results": [],
"count": 0,
"message": "No matching sessions found.",
},
ensure_ascii=False,
)
# Resolve child sessions to their parent — delegation stores detailed
# content in child sessions, but the user's conversation is the parent.
@ -283,12 +282,9 @@ def session_search(
logging.warning(f"Failed to prepare session {session_id}: {e}")
# Summarize all sessions in parallel
async def _summarize_all() -> List[Union[str, Exception]]:
async def _summarize_all() -> list[str | Exception]:
"""Summarize all sessions in parallel."""
coros = [
_summarize_session(text, query, meta)
for _, _, text, meta in tasks
]
coros = [_summarize_session(text, query, meta) for _, _, text, meta in tasks]
return await asyncio.gather(*coros, return_exceptions=True)
try:
@ -300,10 +296,13 @@ def session_search(
results = asyncio.run(_summarize_all())
except concurrent.futures.TimeoutError:
logging.warning("Session summarization timed out after 60 seconds")
return json.dumps({
"success": False,
"error": "Session summarization timed out. Try a more specific query or reduce the limit.",
}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"error": "Session summarization timed out. Try a more specific query or reduce the limit.",
},
ensure_ascii=False,
)
summaries = []
for (session_id, match_info, _, _), result in zip(tasks, results):
@ -311,21 +310,26 @@ def session_search(
logging.warning(f"Failed to summarize session {session_id}: {result}")
continue
if result:
summaries.append({
"session_id": session_id,
"when": _format_timestamp(match_info.get("session_started")),
"source": match_info.get("source", "unknown"),
"model": match_info.get("model"),
"summary": result,
})
summaries.append(
{
"session_id": session_id,
"when": _format_timestamp(match_info.get("session_started")),
"source": match_info.get("source", "unknown"),
"model": match_info.get("model"),
"summary": result,
}
)
return json.dumps({
"success": True,
"query": query,
"results": summaries,
"count": len(summaries),
"sessions_searched": len(seen_sessions),
}, ensure_ascii=False)
return json.dumps(
{
"success": True,
"query": query,
"results": summaries,
"count": len(summaries),
"sessions_searched": len(seen_sessions),
},
ensure_ascii=False,
)
except Exception as e:
return json.dumps({"success": False, "error": f"Search failed: {str(e)}"}, ensure_ascii=False)
@ -337,6 +341,7 @@ def check_session_search_requirements() -> bool:
return False
try:
from hermes_state import DEFAULT_DB_PATH
return DEFAULT_DB_PATH.parent.exists()
except ImportError:
return False
@ -356,7 +361,7 @@ SESSION_SEARCH_SCHEMA = {
"Don't hesitate to search -- it's fast and cheap. Better to search and confirm "
"than to guess or ask the user to repeat themselves.\n\n"
"Search syntax: keywords joined with OR for broad recall (elevenlabs OR baseten OR funding), "
"phrases for exact match (\"docker networking\"), boolean (python NOT java), prefix (deploy*). "
'phrases for exact match ("docker networking"), boolean (python NOT java), prefix (deploy*). '
"IMPORTANT: Use OR between keywords for best results — FTS5 defaults to AND which misses "
"sessions that only mention some terms. If a broad OR query returns nothing, try individual "
"keyword searches in parallel. Returns summaries of the top matching sessions."
@ -395,6 +400,7 @@ registry.register(
role_filter=args.get("role_filter"),
limit=args.get("limit", 3),
db=kw.get("db"),
current_session_id=kw.get("current_session_id")),
current_session_id=kw.get("current_session_id"),
),
check_fn=check_session_search_requirements,
)

View file

@ -38,20 +38,21 @@ import os
import re
import shutil
from pathlib import Path
from typing import Dict, Any, Optional
from typing import Any
logger = logging.getLogger(__name__)
# Import security scanner — agent-created skills get the same scrutiny as
# community hub installs.
try:
from tools.skills_guard import scan_skill, should_allow_install, format_scan_report
from tools.skills_guard import format_scan_report, scan_skill, should_allow_install
_GUARD_AVAILABLE = True
except ImportError:
_GUARD_AVAILABLE = False
def _security_scan_skill(skill_dir: Path) -> Optional[str]:
def _security_scan_skill(skill_dir: Path) -> str | None:
"""Scan a skill directory after write. Returns error string if blocked, else None."""
if not _GUARD_AVAILABLE:
return None
@ -65,8 +66,8 @@ def _security_scan_skill(skill_dir: Path) -> Optional[str]:
logger.warning("Security scan failed for %s: %s", skill_dir, e)
return None
import yaml
import yaml
# All skills live in ~/.hermes/skills/ (single source of truth)
HERMES_HOME = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
@ -76,7 +77,7 @@ MAX_NAME_LENGTH = 64
MAX_DESCRIPTION_LENGTH = 1024
# Characters allowed in skill names (filesystem-safe, URL-friendly)
VALID_NAME_RE = re.compile(r'^[a-z0-9][a-z0-9._-]*$')
VALID_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9._-]*$")
# Subdirectories allowed for write_file/remove_file
ALLOWED_SUBDIRS = {"references", "templates", "scripts", "assets"}
@ -91,7 +92,8 @@ def check_skill_manage_requirements() -> bool:
# Validation helpers
# =============================================================================
def _validate_name(name: str) -> Optional[str]:
def _validate_name(name: str) -> str | None:
"""Validate a skill name. Returns error message or None if valid."""
if not name:
return "Skill name is required."
@ -105,7 +107,7 @@ def _validate_name(name: str) -> Optional[str]:
return None
def _validate_frontmatter(content: str) -> Optional[str]:
def _validate_frontmatter(content: str) -> str | None:
"""
Validate that SKILL.md content has proper frontmatter with required fields.
Returns error message or None if valid.
@ -116,11 +118,11 @@ def _validate_frontmatter(content: str) -> Optional[str]:
if not content.startswith("---"):
return "SKILL.md must start with YAML frontmatter (---). See existing skills for format."
end_match = re.search(r'\n---\s*\n', content[3:])
end_match = re.search(r"\n---\s*\n", content[3:])
if not end_match:
return "SKILL.md frontmatter is not closed. Ensure you have a closing '---' line."
yaml_content = content[3:end_match.start() + 3]
yaml_content = content[3 : end_match.start() + 3]
try:
parsed = yaml.safe_load(yaml_content)
@ -137,7 +139,7 @@ def _validate_frontmatter(content: str) -> Optional[str]:
if len(str(parsed["description"])) > MAX_DESCRIPTION_LENGTH:
return f"Description exceeds {MAX_DESCRIPTION_LENGTH} characters."
body = content[end_match.end() + 3:].strip()
body = content[end_match.end() + 3 :].strip()
if not body:
return "SKILL.md must have content after the frontmatter (instructions, procedures, etc.)."
@ -151,7 +153,7 @@ def _resolve_skill_dir(name: str, category: str = None) -> Path:
return SKILLS_DIR / name
def _find_skill(name: str) -> Optional[Dict[str, Any]]:
def _find_skill(name: str) -> dict[str, Any] | None:
"""
Find a skill by name in ~/.hermes/skills/.
Returns {"path": Path} or None.
@ -164,7 +166,7 @@ def _find_skill(name: str) -> Optional[Dict[str, Any]]:
return None
def _validate_file_path(file_path: str) -> Optional[str]:
def _validate_file_path(file_path: str) -> str | None:
"""
Validate a file path for write_file/remove_file.
Must be under an allowed subdirectory and not escape the skill dir.
@ -194,7 +196,8 @@ def _validate_file_path(file_path: str) -> Optional[str]:
# Core actions
# =============================================================================
def _create_skill(name: str, content: str, category: str = None) -> Dict[str, Any]:
def _create_skill(name: str, content: str, category: str = None) -> dict[str, Any]:
"""Create a new user skill with SKILL.md content."""
# Validate name
err = _validate_name(name)
@ -209,10 +212,7 @@ def _create_skill(name: str, content: str, category: str = None) -> Dict[str, An
# Check for name collisions across all directories
existing = _find_skill(name)
if existing:
return {
"success": False,
"error": f"A skill named '{name}' already exists at {existing['path']}."
}
return {"success": False, "error": f"A skill named '{name}' already exists at {existing['path']}."}
# Create the skill directory
skill_dir = _resolve_skill_dir(name, category)
@ -238,12 +238,12 @@ def _create_skill(name: str, content: str, category: str = None) -> Dict[str, An
result["category"] = category
result["hint"] = (
"To add reference files, templates, or scripts, use "
"skill_manage(action='write_file', name='{}', file_path='references/example.md', file_content='...')".format(name)
f"skill_manage(action='write_file', name='{name}', file_path='references/example.md', file_content='...')"
)
return result
def _edit_skill(name: str, content: str) -> Dict[str, Any]:
def _edit_skill(name: str, content: str) -> dict[str, Any]:
"""Replace the SKILL.md of any existing skill (full rewrite)."""
err = _validate_frontmatter(content)
if err:
@ -278,7 +278,7 @@ def _patch_skill(
new_string: str,
file_path: str = None,
replace_all: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Targeted find-and-replace within a skill file.
Defaults to SKILL.md. Use file_path to patch a supporting file instead.
@ -287,7 +287,10 @@ def _patch_skill(
if not old_string:
return {"success": False, "error": "old_string is required for 'patch'."}
if new_string is None:
return {"success": False, "error": "new_string is required for 'patch'. Use an empty string to delete matched text."}
return {
"success": False,
"error": "new_string is required for 'patch'. Use an empty string to delete matched text.",
}
existing = _find_skill(name)
if not existing:
@ -357,7 +360,7 @@ def _patch_skill(
}
def _delete_skill(name: str) -> Dict[str, Any]:
def _delete_skill(name: str) -> dict[str, Any]:
"""Delete a skill."""
existing = _find_skill(name)
if not existing:
@ -377,7 +380,7 @@ def _delete_skill(name: str) -> Dict[str, Any]:
}
def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]:
def _write_file(name: str, file_path: str, file_content: str) -> dict[str, Any]:
"""Add or overwrite a supporting file within any skill directory."""
err = _validate_file_path(file_path)
if err:
@ -412,7 +415,7 @@ def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]:
}
def _remove_file(name: str, file_path: str) -> Dict[str, Any]:
def _remove_file(name: str, file_path: str) -> dict[str, Any]:
"""Remove a supporting file from any skill directory."""
err = _validate_file_path(file_path)
if err:
@ -456,6 +459,7 @@ def _remove_file(name: str, file_path: str) -> Dict[str, Any]:
# Main entry point
# =============================================================================
def skill_manage(
action: str,
name: str,
@ -474,19 +478,37 @@ def skill_manage(
"""
if action == "create":
if not content:
return json.dumps({"success": False, "error": "content is required for 'create'. Provide the full SKILL.md text (frontmatter + body)."}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"error": "content is required for 'create'. Provide the full SKILL.md text (frontmatter + body).",
},
ensure_ascii=False,
)
result = _create_skill(name, content, category)
elif action == "edit":
if not content:
return json.dumps({"success": False, "error": "content is required for 'edit'. Provide the full updated SKILL.md text."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "content is required for 'edit'. Provide the full updated SKILL.md text."},
ensure_ascii=False,
)
result = _edit_skill(name, content)
elif action == "patch":
if not old_string:
return json.dumps({"success": False, "error": "old_string is required for 'patch'. Provide the text to find."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "old_string is required for 'patch'. Provide the text to find."},
ensure_ascii=False,
)
if new_string is None:
return json.dumps({"success": False, "error": "new_string is required for 'patch'. Use empty string to delete matched text."}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"error": "new_string is required for 'patch'. Use empty string to delete matched text.",
},
ensure_ascii=False,
)
result = _patch_skill(name, old_string, new_string, file_path, replace_all)
elif action == "delete":
@ -494,18 +516,31 @@ def skill_manage(
elif action == "write_file":
if not file_path:
return json.dumps({"success": False, "error": "file_path is required for 'write_file'. Example: 'references/api-guide.md'"}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"error": "file_path is required for 'write_file'. Example: 'references/api-guide.md'",
},
ensure_ascii=False,
)
if file_content is None:
return json.dumps({"success": False, "error": "file_content is required for 'write_file'."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "file_content is required for 'write_file'."}, ensure_ascii=False
)
result = _write_file(name, file_path, file_content)
elif action == "remove_file":
if not file_path:
return json.dumps({"success": False, "error": "file_path is required for 'remove_file'."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "file_path is required for 'remove_file'."}, ensure_ascii=False
)
result = _remove_file(name, file_path)
else:
result = {"success": False, "error": f"Unknown action '{action}'. Use: create, edit, patch, delete, write_file, remove_file"}
result = {
"success": False,
"error": f"Unknown action '{action}'. Use: create, edit, patch, delete, write_file, remove_file",
}
return json.dumps(result, ensure_ascii=False)
@ -540,14 +575,14 @@ SKILL_MANAGE_SCHEMA = {
"action": {
"type": "string",
"enum": ["create", "patch", "edit", "delete", "write_file", "remove_file"],
"description": "The action to perform."
"description": "The action to perform.",
},
"name": {
"type": "string",
"description": (
"Skill name (lowercase, hyphens/underscores, max 64 chars). "
"Must match an existing skill for patch/edit/delete/write_file/remove_file."
)
),
},
"content": {
"type": "string",
@ -555,7 +590,7 @@ SKILL_MANAGE_SCHEMA = {
"Full SKILL.md content (YAML frontmatter + markdown body). "
"Required for 'create' and 'edit'. For 'edit', read the skill "
"first with skill_view() and provide the complete updated text."
)
),
},
"old_string": {
"type": "string",
@ -563,18 +598,17 @@ SKILL_MANAGE_SCHEMA = {
"Text to find in the file (required for 'patch'). Must be unique "
"unless replace_all=true. Include enough surrounding context to "
"ensure uniqueness."
)
),
},
"new_string": {
"type": "string",
"description": (
"Replacement text (required for 'patch'). Can be empty string "
"to delete the matched text."
)
"Replacement text (required for 'patch'). Can be empty string to delete the matched text."
),
},
"replace_all": {
"type": "boolean",
"description": "For 'patch': replace all occurrences instead of requiring a unique match (default: false)."
"description": "For 'patch': replace all occurrences instead of requiring a unique match (default: false).",
},
"category": {
"type": "string",
@ -582,7 +616,7 @@ SKILL_MANAGE_SCHEMA = {
"Optional category/domain for organizing the skill (e.g., 'devops', "
"'data-science', 'mlops'). Creates a subdirectory grouping. "
"Only used with 'create'."
)
),
},
"file_path": {
"type": "string",
@ -591,12 +625,9 @@ SKILL_MANAGE_SCHEMA = {
"For 'write_file'/'remove_file': required, must be under references/, "
"templates/, scripts/, or assets/. "
"For 'patch': optional, defaults to SKILL.md if omitted."
)
},
"file_content": {
"type": "string",
"description": "Content for the file. Required for 'write_file'."
),
},
"file_content": {"type": "string", "description": "Content for the file. Required for 'write_file'."},
},
"required": ["action", "name"],
},
@ -619,5 +650,6 @@ registry.register(
file_content=args.get("file_content"),
old_string=args.get("old_string"),
new_string=args.get("new_string"),
replace_all=args.get("replace_all", False)),
replace_all=args.get("replace_all", False),
),
)

File diff suppressed because it is too large Load diff

View file

@ -23,15 +23,17 @@ import subprocess
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime, timezone
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Any
import httpx
import yaml
from tools.skills_guard import (
ScanResult, scan_skill, should_allow_install, content_hash, TRUSTED_REPOS,
TRUSTED_REPOS,
ScanResult,
content_hash,
)
logger = logging.getLogger(__name__)
@ -58,24 +60,27 @@ INDEX_CACHE_TTL = 3600 # 1 hour
# Data models
# ---------------------------------------------------------------------------
@dataclass
class SkillMeta:
"""Minimal metadata returned by search results."""
name: str
description: str
source: str # "official", "github", "clawhub", "claude-marketplace", "lobehub"
identifier: str # source-specific ID (e.g. "openai/skills/skill-creator")
trust_level: str # "builtin" | "trusted" | "community"
repo: Optional[str] = None
path: Optional[str] = None
tags: List[str] = field(default_factory=list)
source: str # "official", "github", "clawhub", "claude-marketplace", "lobehub"
identifier: str # source-specific ID (e.g. "openai/skills/skill-creator")
trust_level: str # "builtin" | "trusted" | "community"
repo: str | None = None
path: str | None = None
tags: list[str] = field(default_factory=list)
@dataclass
class SkillBundle:
"""A downloaded skill ready for quarantine/scanning/installation."""
name: str
files: Dict[str, str] # relative_path -> text content
files: dict[str, str] # relative_path -> text content
source: str
identifier: str
trust_level: str
@ -85,6 +90,7 @@ class SkillBundle:
# GitHub Authentication
# ---------------------------------------------------------------------------
class GitHubAuth:
"""
GitHub API authentication. Tries methods in priority order:
@ -95,11 +101,11 @@ class GitHubAuth:
"""
def __init__(self):
self._cached_token: Optional[str] = None
self._cached_method: Optional[str] = None
self._cached_token: str | None = None
self._cached_method: str | None = None
self._app_token_expiry: float = 0
def get_headers(self) -> Dict[str, str]:
def get_headers(self) -> dict[str, str]:
"""Return authorization headers for GitHub API requests."""
token = self._resolve_token()
headers = {"Accept": "application/vnd.github.v3+json"}
@ -115,7 +121,7 @@ class GitHubAuth:
self._resolve_token()
return self._cached_method or "anonymous"
def _resolve_token(self) -> Optional[str]:
def _resolve_token(self) -> str | None:
# Return cached token if still valid
if self._cached_token:
if self._cached_method != "github-app" or time.time() < self._app_token_expiry:
@ -146,12 +152,14 @@ class GitHubAuth:
self._cached_method = "anonymous"
return None
def _try_gh_cli(self) -> Optional[str]:
def _try_gh_cli(self) -> str | None:
"""Try to get a token from the gh CLI."""
try:
result = subprocess.run(
["gh", "auth", "token"],
capture_output=True, text=True, timeout=5,
capture_output=True,
text=True,
timeout=5,
)
if result.returncode == 0 and result.stdout.strip():
return result.stdout.strip()
@ -159,7 +167,7 @@ class GitHubAuth:
logger.debug("gh CLI token lookup failed: %s", e)
return None
def _try_github_app(self) -> Optional[str]:
def _try_github_app(self) -> str | None:
"""Try GitHub App JWT authentication if credentials are configured."""
app_id = os.environ.get("GITHUB_APP_ID")
key_path = os.environ.get("GITHUB_APP_PRIVATE_KEY_PATH")
@ -208,21 +216,22 @@ class GitHubAuth:
# Source adapter interface
# ---------------------------------------------------------------------------
class SkillSource(ABC):
"""Abstract base for all skill registry adapters."""
@abstractmethod
def search(self, query: str, limit: int = 10) -> List[SkillMeta]:
def search(self, query: str, limit: int = 10) -> list[SkillMeta]:
"""Search for skills matching a query string."""
...
@abstractmethod
def fetch(self, identifier: str) -> Optional[SkillBundle]:
def fetch(self, identifier: str) -> SkillBundle | None:
"""Download a skill bundle by identifier."""
...
@abstractmethod
def inspect(self, identifier: str) -> Optional[SkillMeta]:
def inspect(self, identifier: str) -> SkillMeta | None:
"""Fetch metadata for a skill without downloading all files."""
...
@ -240,6 +249,7 @@ class SkillSource(ABC):
# GitHub source adapter
# ---------------------------------------------------------------------------
class GitHubSource(SkillSource):
"""Fetch skills from GitHub repos via the Contents API."""
@ -249,7 +259,7 @@ class GitHubSource(SkillSource):
{"repo": "VoltAgent/awesome-agent-skills", "path": "skills/"},
]
def __init__(self, auth: GitHubAuth, extra_taps: Optional[List[Dict]] = None):
def __init__(self, auth: GitHubAuth, extra_taps: list[dict] | None = None):
self.auth = auth
self.taps = list(self.DEFAULT_TAPS)
if extra_taps:
@ -267,9 +277,9 @@ class GitHubSource(SkillSource):
return "trusted"
return "community"
def search(self, query: str, limit: int = 10) -> List[SkillMeta]:
def search(self, query: str, limit: int = 10) -> list[SkillMeta]:
"""Search all taps for skills matching the query."""
results: List[SkillMeta] = []
results: list[SkillMeta] = []
query_lower = query.lower()
for tap in self.taps:
@ -287,15 +297,13 @@ class GitHubSource(SkillSource):
_trust_rank = {"builtin": 2, "trusted": 1, "community": 0}
seen = {}
for r in results:
if r.name not in seen:
seen[r.name] = r
elif _trust_rank.get(r.trust_level, 0) > _trust_rank.get(seen[r.name].trust_level, 0):
if r.name not in seen or _trust_rank.get(r.trust_level, 0) > _trust_rank.get(seen[r.name].trust_level, 0):
seen[r.name] = r
results = list(seen.values())
return results[:limit]
def fetch(self, identifier: str) -> Optional[SkillBundle]:
def fetch(self, identifier: str) -> SkillBundle | None:
"""
Download a skill from GitHub.
identifier format: "owner/repo/path/to/skill-dir"
@ -322,7 +330,7 @@ class GitHubSource(SkillSource):
trust_level=trust,
)
def inspect(self, identifier: str) -> Optional[SkillMeta]:
def inspect(self, identifier: str) -> SkillMeta | None:
"""Fetch just the SKILL.md metadata for preview."""
parts = identifier.split("/", 2)
if len(parts) < 3:
@ -363,7 +371,7 @@ class GitHubSource(SkillSource):
# -- Internal helpers --
def _list_skills_in_repo(self, repo: str, path: str) -> List[SkillMeta]:
def _list_skills_in_repo(self, repo: str, path: str) -> list[SkillMeta]:
"""List skill directories in a GitHub repo path, using cached index."""
cache_key = f"{repo}_{path}".replace("/", "_").replace(" ", "_")
cached = self._read_cache(cache_key)
@ -382,7 +390,7 @@ class GitHubSource(SkillSource):
if not isinstance(entries, list):
return []
skills: List[SkillMeta] = []
skills: list[SkillMeta] = []
for entry in entries:
if entry.get("type") != "dir":
continue
@ -400,7 +408,7 @@ class GitHubSource(SkillSource):
self._write_cache(cache_key, [self._meta_to_dict(s) for s in skills])
return skills
def _download_directory(self, repo: str, path: str) -> Dict[str, str]:
def _download_directory(self, repo: str, path: str) -> dict[str, str]:
"""Recursively download all text files from a GitHub directory."""
url = f"https://api.github.com/repos/{repo}/contents/{path.rstrip('/')}"
try:
@ -414,7 +422,7 @@ class GitHubSource(SkillSource):
if not isinstance(entries, list):
return {}
files: Dict[str, str] = {}
files: dict[str, str] = {}
for entry in entries:
name = entry.get("name", "")
entry_type = entry.get("type", "")
@ -431,7 +439,7 @@ class GitHubSource(SkillSource):
return files
def _fetch_file_content(self, repo: str, path: str) -> Optional[str]:
def _fetch_file_content(self, repo: str, path: str) -> str | None:
"""Fetch a single file's content from GitHub."""
url = f"https://api.github.com/repos/{repo}/contents/{path}"
try:
@ -446,7 +454,7 @@ class GitHubSource(SkillSource):
logger.debug("GitHub contents API fetch failed: %s", e)
return None
def _read_cache(self, key: str) -> Optional[list]:
def _read_cache(self, key: str) -> list | None:
"""Read cached index if not expired."""
cache_file = INDEX_CACHE_DIR / f"{key}.json"
if not cache_file.exists():
@ -486,10 +494,10 @@ class GitHubSource(SkillSource):
"""Parse YAML frontmatter from SKILL.md content."""
if not content.startswith("---"):
return {}
match = re.search(r'\n---\s*\n', content[3:])
match = re.search(r"\n---\s*\n", content[3:])
if not match:
return {}
yaml_text = content[3:match.start() + 3]
yaml_text = content[3 : match.start() + 3]
try:
parsed = yaml.safe_load(yaml_text)
return parsed if isinstance(parsed, dict) else {}
@ -501,6 +509,7 @@ class GitHubSource(SkillSource):
# ClawHub source adapter
# ---------------------------------------------------------------------------
class ClawHubSource(SkillSource):
"""
Fetch skills from ClawHub (clawhub.ai) via their HTTP API.
@ -516,7 +525,7 @@ class ClawHubSource(SkillSource):
def trust_level_for(self, identifier: str) -> str:
return "community"
def search(self, query: str, limit: int = 10) -> List[SkillMeta]:
def search(self, query: str, limit: int = 10) -> list[SkillMeta]:
cache_key = f"clawhub_search_{hashlib.md5(query.encode()).hexdigest()}"
cached = _read_index_cache(cache_key)
if cached is not None:
@ -548,19 +557,21 @@ class ClawHubSource(SkillSource):
tags = item.get("tags", [])
if not isinstance(tags, list):
tags = []
results.append(SkillMeta(
name=display_name,
description=summary,
source="clawhub",
identifier=slug,
trust_level="community",
tags=[str(t) for t in tags],
))
results.append(
SkillMeta(
name=display_name,
description=summary,
source="clawhub",
identifier=slug,
trust_level="community",
tags=[str(t) for t in tags],
)
)
_write_index_cache(cache_key, [_skill_meta_to_dict(s) for s in results])
return results
def fetch(self, identifier: str) -> Optional[SkillBundle]:
def fetch(self, identifier: str) -> SkillBundle | None:
slug = identifier.split("/")[-1]
skill_data = self._get_json(f"{self.BASE_URL}/skills/{slug}")
@ -593,7 +604,7 @@ class ClawHubSource(SkillSource):
trust_level="community",
)
def inspect(self, identifier: str) -> Optional[SkillMeta]:
def inspect(self, identifier: str) -> SkillMeta | None:
slug = identifier.split("/")[-1]
data = self._get_json(f"{self.BASE_URL}/skills/{slug}")
if not isinstance(data, dict):
@ -612,7 +623,7 @@ class ClawHubSource(SkillSource):
tags=[str(t) for t in tags],
)
def _get_json(self, url: str, timeout: int = 20) -> Optional[Any]:
def _get_json(self, url: str, timeout: int = 20) -> Any | None:
try:
resp = httpx.get(url, timeout=timeout)
if resp.status_code != 200:
@ -621,7 +632,7 @@ class ClawHubSource(SkillSource):
except (httpx.HTTPError, json.JSONDecodeError):
return None
def _resolve_latest_version(self, slug: str, skill_data: Dict[str, Any]) -> Optional[str]:
def _resolve_latest_version(self, slug: str, skill_data: dict[str, Any]) -> str | None:
latest = skill_data.get("latestVersion")
if isinstance(latest, dict):
version = latest.get("version")
@ -643,8 +654,8 @@ class ClawHubSource(SkillSource):
return version
return None
def _extract_files(self, version_data: Dict[str, Any]) -> Dict[str, str]:
files: Dict[str, str] = {}
def _extract_files(self, version_data: dict[str, Any]) -> dict[str, str]:
files: dict[str, str] = {}
file_list = version_data.get("files")
if isinstance(file_list, dict):
@ -674,7 +685,7 @@ class ClawHubSource(SkillSource):
return files
def _fetch_text(self, url: str) -> Optional[str]:
def _fetch_text(self, url: str) -> str | None:
try:
resp = httpx.get(url, timeout=20)
if resp.status_code == 200:
@ -688,6 +699,7 @@ class ClawHubSource(SkillSource):
# Claude Code marketplace source adapter
# ---------------------------------------------------------------------------
class ClaudeMarketplaceSource(SkillSource):
"""
Discover skills from Claude Code marketplace repos.
@ -713,8 +725,8 @@ class ClaudeMarketplaceSource(SkillSource):
return "trusted"
return "community"
def search(self, query: str, limit: int = 10) -> List[SkillMeta]:
results: List[SkillMeta] = []
def search(self, query: str, limit: int = 10) -> list[SkillMeta]:
results: list[SkillMeta] = []
query_lower = query.lower()
for marketplace_repo in self.KNOWN_MARKETPLACES:
@ -730,18 +742,20 @@ class ClaudeMarketplaceSource(SkillSource):
else:
identifier = f"{marketplace_repo}/{source_path}"
results.append(SkillMeta(
name=plugin.get("name", ""),
description=plugin.get("description", ""),
source="claude-marketplace",
identifier=identifier,
trust_level=self.trust_level_for(identifier),
repo=marketplace_repo,
))
results.append(
SkillMeta(
name=plugin.get("name", ""),
description=plugin.get("description", ""),
source="claude-marketplace",
identifier=identifier,
trust_level=self.trust_level_for(identifier),
repo=marketplace_repo,
)
)
return results[:limit]
def fetch(self, identifier: str) -> Optional[SkillBundle]:
def fetch(self, identifier: str) -> SkillBundle | None:
# Delegate to GitHub Contents API since marketplace skills live in GitHub repos
gh = GitHubSource(auth=self.auth)
bundle = gh.fetch(identifier)
@ -749,7 +763,7 @@ class ClaudeMarketplaceSource(SkillSource):
bundle.source = "claude-marketplace"
return bundle
def inspect(self, identifier: str) -> Optional[SkillMeta]:
def inspect(self, identifier: str) -> SkillMeta | None:
gh = GitHubSource(auth=self.auth)
meta = gh.inspect(identifier)
if meta:
@ -757,7 +771,7 @@ class ClaudeMarketplaceSource(SkillSource):
meta.trust_level = self.trust_level_for(identifier)
return meta
def _fetch_marketplace_index(self, repo: str) -> List[dict]:
def _fetch_marketplace_index(self, repo: str) -> list[dict]:
"""Fetch and parse .claude-plugin/marketplace.json from a repo."""
cache_key = f"claude_marketplace_{repo.replace('/', '_')}"
cached = _read_index_cache(cache_key)
@ -786,6 +800,7 @@ class ClaudeMarketplaceSource(SkillSource):
# LobeHub source adapter
# ---------------------------------------------------------------------------
class LobeHubSource(SkillSource):
"""
Fetch skills from LobeHub's agent marketplace (14,500+ agents).
@ -802,13 +817,13 @@ class LobeHubSource(SkillSource):
def trust_level_for(self, identifier: str) -> str:
return "community"
def search(self, query: str, limit: int = 10) -> List[SkillMeta]:
def search(self, query: str, limit: int = 10) -> list[SkillMeta]:
index = self._fetch_index()
if not index:
return []
query_lower = query.lower()
results: List[SkillMeta] = []
results: list[SkillMeta] = []
agents = index.get("agents", index) if isinstance(index, dict) else index
if not isinstance(agents, list):
@ -823,21 +838,23 @@ class LobeHubSource(SkillSource):
searchable = f"{title} {desc} {' '.join(tags) if isinstance(tags, list) else ''}".lower()
if query_lower in searchable:
identifier = agent.get("identifier", title.lower().replace(" ", "-"))
results.append(SkillMeta(
name=identifier,
description=desc[:200],
source="lobehub",
identifier=f"lobehub/{identifier}",
trust_level="community",
tags=tags if isinstance(tags, list) else [],
))
results.append(
SkillMeta(
name=identifier,
description=desc[:200],
source="lobehub",
identifier=f"lobehub/{identifier}",
trust_level="community",
tags=tags if isinstance(tags, list) else [],
)
)
if len(results) >= limit:
break
return results
def fetch(self, identifier: str) -> Optional[SkillBundle]:
def fetch(self, identifier: str) -> SkillBundle | None:
# Strip "lobehub/" prefix if present
agent_id = identifier.split("/", 1)[-1] if identifier.startswith("lobehub/") else identifier
@ -854,7 +871,7 @@ class LobeHubSource(SkillSource):
trust_level="community",
)
def inspect(self, identifier: str) -> Optional[SkillMeta]:
def inspect(self, identifier: str) -> SkillMeta | None:
agent_id = identifier.split("/", 1)[-1] if identifier.startswith("lobehub/") else identifier
index = self._fetch_index()
if not index:
@ -877,7 +894,7 @@ class LobeHubSource(SkillSource):
)
return None
def _fetch_index(self) -> Optional[Any]:
def _fetch_index(self) -> Any | None:
"""Fetch the LobeHub agent index (cached for 1 hour)."""
cache_key = "lobehub_index"
cached = _read_index_cache(cache_key)
@ -895,7 +912,7 @@ class LobeHubSource(SkillSource):
_write_index_cache(cache_key, data)
return data
def _fetch_agent(self, agent_id: str) -> Optional[dict]:
def _fetch_agent(self, agent_id: str) -> dict | None:
"""Fetch a single agent's JSON file."""
url = f"https://chat-agents.lobehub.com/{agent_id}.json"
try:
@ -924,8 +941,8 @@ class LobeHubSource(SkillSource):
"metadata:",
" hermes:",
f" tags: [{', '.join(str(t) for t in tag_list)}]",
f" lobehub:",
f" source: lobehub",
" lobehub:",
" source: lobehub",
"---",
]
@ -946,6 +963,7 @@ class LobeHubSource(SkillSource):
# Official optional skills source adapter
# ---------------------------------------------------------------------------
class OptionalSkillSource(SkillSource):
"""
Fetch skills from the optional-skills/ directory shipped with the repo.
@ -967,8 +985,8 @@ class OptionalSkillSource(SkillSource):
# -- search -----------------------------------------------------------
def search(self, query: str, limit: int = 10) -> List[SkillMeta]:
results: List[SkillMeta] = []
def search(self, query: str, limit: int = 10) -> list[SkillMeta]:
results: list[SkillMeta] = []
query_lower = query.lower()
for meta in self._scan_all():
@ -982,7 +1000,7 @@ class OptionalSkillSource(SkillSource):
# -- fetch ------------------------------------------------------------
def fetch(self, identifier: str) -> Optional[SkillBundle]:
def fetch(self, identifier: str) -> SkillBundle | None:
# identifier format: "official/category/skill" or "official/skill"
rel = identifier.split("/", 1)[-1] if identifier.startswith("official/") else identifier
skill_dir = self._optional_dir / rel
@ -1004,7 +1022,7 @@ class OptionalSkillSource(SkillSource):
else:
skill_dir = resolved
files: Dict[str, str] = {}
files: dict[str, str] = {}
for f in skill_dir.rglob("*"):
if f.is_file() and not f.name.startswith("."):
rel_path = str(f.relative_to(skill_dir))
@ -1029,7 +1047,7 @@ class OptionalSkillSource(SkillSource):
# -- inspect ----------------------------------------------------------
def inspect(self, identifier: str) -> Optional[SkillMeta]:
def inspect(self, identifier: str) -> SkillMeta | None:
rel = identifier.split("/", 1)[-1] if identifier.startswith("official/") else identifier
skill_name = rel.rsplit("/", 1)[-1]
@ -1040,7 +1058,7 @@ class OptionalSkillSource(SkillSource):
# -- internal helpers -------------------------------------------------
def _find_skill_dir(self, name: str) -> Optional[Path]:
def _find_skill_dir(self, name: str) -> Path | None:
"""Find a skill directory by name anywhere in optional-skills/."""
if not self._optional_dir.is_dir():
return None
@ -1049,12 +1067,12 @@ class OptionalSkillSource(SkillSource):
return skill_md.parent
return None
def _scan_all(self) -> List[SkillMeta]:
def _scan_all(self) -> list[SkillMeta]:
"""Enumerate all optional skills with metadata."""
if not self._optional_dir.is_dir():
return []
results: List[SkillMeta] = []
results: list[SkillMeta] = []
for skill_md in sorted(self._optional_dir.rglob("SKILL.md")):
parent = skill_md.parent
rel_parts = parent.relative_to(self._optional_dir).parts
@ -1078,15 +1096,17 @@ class OptionalSkillSource(SkillSource):
rel_path = str(parent.relative_to(self._optional_dir))
results.append(SkillMeta(
name=name,
description=desc[:200],
source="official",
identifier=f"official/{rel_path}",
trust_level="builtin",
path=rel_path,
tags=tags if isinstance(tags, list) else [],
))
results.append(
SkillMeta(
name=name,
description=desc[:200],
source="official",
identifier=f"official/{rel_path}",
trust_level="builtin",
path=rel_path,
tags=tags if isinstance(tags, list) else [],
)
)
return results
@ -1095,10 +1115,10 @@ class OptionalSkillSource(SkillSource):
"""Parse YAML frontmatter from SKILL.md content."""
if not content.startswith("---"):
return {}
match = re.search(r'\n---\s*\n', content[3:])
match = re.search(r"\n---\s*\n", content[3:])
if not match:
return {}
yaml_text = content[3:match.start() + 3]
yaml_text = content[3 : match.start() + 3]
try:
parsed = yaml.safe_load(yaml_text)
return parsed if isinstance(parsed, dict) else {}
@ -1110,7 +1130,8 @@ class OptionalSkillSource(SkillSource):
# Shared cache helpers (used by multiple adapters)
# ---------------------------------------------------------------------------
def _read_index_cache(key: str) -> Optional[Any]:
def _read_index_cache(key: str) -> Any | None:
"""Read cached data if not expired."""
cache_file = INDEX_CACHE_DIR / f"{key}.json"
if not cache_file.exists():
@ -1152,6 +1173,7 @@ def _skill_meta_to_dict(meta: SkillMeta) -> dict:
# Lock file management
# ---------------------------------------------------------------------------
class HubLockFile:
"""Manages skills/.hub/lock.json — tracks provenance of installed hub skills."""
@ -1179,7 +1201,7 @@ class HubLockFile:
scan_verdict: str,
skill_hash: str,
install_path: str,
files: List[str],
files: list[str],
) -> None:
data = self.load()
data["installed"][name] = {
@ -1190,8 +1212,8 @@ class HubLockFile:
"content_hash": skill_hash,
"install_path": install_path,
"files": files,
"installed_at": datetime.now(timezone.utc).isoformat(),
"updated_at": datetime.now(timezone.utc).isoformat(),
"installed_at": datetime.now(UTC).isoformat(),
"updated_at": datetime.now(UTC).isoformat(),
}
self.save(data)
@ -1200,11 +1222,11 @@ class HubLockFile:
data["installed"].pop(name, None)
self.save(data)
def get_installed(self, name: str) -> Optional[dict]:
def get_installed(self, name: str) -> dict | None:
data = self.load()
return data["installed"].get(name)
def list_installed(self) -> List[dict]:
def list_installed(self) -> list[dict]:
data = self.load()
result = []
for name, entry in data["installed"].items():
@ -1220,13 +1242,14 @@ class HubLockFile:
# Taps management
# ---------------------------------------------------------------------------
class TapsManager:
"""Manages the taps.json file — custom GitHub repo sources."""
def __init__(self, path: Path = TAPS_FILE):
self.path = path
def load(self) -> List[dict]:
def load(self) -> list[dict]:
if not self.path.exists():
return []
try:
@ -1235,7 +1258,7 @@ class TapsManager:
except (json.JSONDecodeError, OSError):
return []
def save(self, taps: List[dict]) -> None:
def save(self, taps: list[dict]) -> None:
self.path.parent.mkdir(parents=True, exist_ok=True)
self.path.write_text(json.dumps({"taps": taps}, indent=2) + "\n")
@ -1257,7 +1280,7 @@ class TapsManager:
self.save(new_taps)
return True
def list_taps(self) -> List[dict]:
def list_taps(self) -> list[dict]:
return self.load()
@ -1265,11 +1288,13 @@ class TapsManager:
# Audit log
# ---------------------------------------------------------------------------
def append_audit_log(action: str, skill_name: str, source: str,
trust_level: str, verdict: str, extra: str = "") -> None:
def append_audit_log(
action: str, skill_name: str, source: str, trust_level: str, verdict: str, extra: str = ""
) -> None:
"""Append a line to the audit log."""
AUDIT_LOG.parent.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
timestamp = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
parts = [timestamp, action, skill_name, f"{source}:{trust_level}", verdict]
if extra:
parts.append(extra)
@ -1285,6 +1310,7 @@ def append_audit_log(action: str, skill_name: str, source: str,
# Hub operations (high-level)
# ---------------------------------------------------------------------------
def ensure_hub_dirs() -> None:
"""Create the .hub directory structure if it doesn't exist."""
HUB_DIR.mkdir(parents=True, exist_ok=True)
@ -1347,15 +1373,18 @@ def install_from_quarantine(
)
append_audit_log(
"INSTALL", skill_name, bundle.source,
bundle.trust_level, scan_result.verdict,
"INSTALL",
skill_name,
bundle.source,
bundle.trust_level,
scan_result.verdict,
content_hash(install_dir),
)
return install_dir
def uninstall_skill(skill_name: str) -> Tuple[bool, str]:
def uninstall_skill(skill_name: str) -> tuple[bool, str]:
"""Remove a hub-installed skill. Refuses to remove builtins."""
lock = HubLockFile()
entry = lock.get_installed(skill_name)
@ -1372,7 +1401,7 @@ def uninstall_skill(skill_name: str) -> Tuple[bool, str]:
return True, f"Uninstalled '{skill_name}' from {entry['install_path']}"
def create_source_router(auth: Optional[GitHubAuth] = None) -> List[SkillSource]:
def create_source_router(auth: GitHubAuth | None = None) -> list[SkillSource]:
"""
Create all configured source adapters.
Returns a list of active sources for search/fetch operations.
@ -1383,8 +1412,8 @@ def create_source_router(auth: Optional[GitHubAuth] = None) -> List[SkillSource]
taps_mgr = TapsManager()
extra_taps = taps_mgr.list_taps()
sources: List[SkillSource] = [
OptionalSkillSource(), # Official optional skills (highest priority)
sources: list[SkillSource] = [
OptionalSkillSource(), # Official optional skills (highest priority)
GitHubSource(auth=auth, extra_taps=extra_taps),
ClawHubSource(),
ClaudeMarketplaceSource(auth=auth),
@ -1394,10 +1423,11 @@ def create_source_router(auth: Optional[GitHubAuth] = None) -> List[SkillSource]
return sources
def unified_search(query: str, sources: List[SkillSource],
source_filter: str = "all", limit: int = 10) -> List[SkillMeta]:
def unified_search(
query: str, sources: list[SkillSource], source_filter: str = "all", limit: int = 10
) -> list[SkillMeta]:
"""Search all sources and merge results."""
all_results: List[SkillMeta] = []
all_results: list[SkillMeta] = []
for src in sources:
if source_filter != "all" and src.source_id() != source_filter:
@ -1410,11 +1440,9 @@ def unified_search(query: str, sources: List[SkillSource],
# Deduplicate by name, preferring higher trust levels
_TRUST_RANK = {"builtin": 2, "trusted": 1, "community": 0}
seen: Dict[str, SkillMeta] = {}
seen: dict[str, SkillMeta] = {}
for r in all_results:
if r.name not in seen:
seen[r.name] = r
elif _TRUST_RANK.get(r.trust_level, 0) > _TRUST_RANK.get(seen[r.name].trust_level, 0):
if r.name not in seen or _TRUST_RANK.get(r.trust_level, 0) > _TRUST_RANK.get(seen[r.name].trust_level, 0):
seen[r.name] = r
deduped = list(seen.values())

View file

@ -26,7 +26,6 @@ import logging
import os
import shutil
from pathlib import Path
from typing import Dict, List, Tuple
logger = logging.getLogger(__name__)
@ -41,7 +40,7 @@ def _get_bundled_dir() -> Path:
return Path(__file__).parent.parent / "skills"
def _read_manifest() -> Dict[str, str]:
def _read_manifest() -> dict[str, str]:
"""
Read the manifest as a dict of {skill_name: origin_hash}.
@ -64,11 +63,11 @@ def _read_manifest() -> Dict[str, str]:
# v1 format: plain name — empty hash triggers migration
result[line] = ""
return result
except (OSError, IOError):
except OSError:
return {}
def _write_manifest(entries: Dict[str, str]):
def _write_manifest(entries: dict[str, str]):
"""Write the manifest file atomically in v2 format (name:hash).
Uses a temp file + os.replace() to avoid corruption if the process
@ -101,7 +100,7 @@ def _write_manifest(entries: Dict[str, str]):
logger.debug("Failed to write skills manifest %s: %s", MANIFEST_FILE, e, exc_info=True)
def _discover_bundled_skills(bundled_dir: Path) -> List[Tuple[str, Path]]:
def _discover_bundled_skills(bundled_dir: Path) -> list[tuple[str, Path]]:
"""
Find all SKILL.md files in the bundled directory.
Returns list of (skill_name, skill_directory_path) tuples.
@ -139,7 +138,7 @@ def _dir_hash(directory: Path) -> str:
rel = fpath.relative_to(directory)
hasher.update(str(rel).encode("utf-8"))
hasher.update(fpath.read_bytes())
except (OSError, IOError):
except OSError:
pass
return hasher.hexdigest()
@ -155,8 +154,12 @@ def sync_skills(quiet: bool = False) -> dict:
bundled_dir = _get_bundled_dir()
if not bundled_dir.exists():
return {
"copied": [], "updated": [], "skipped": 0,
"user_modified": [], "cleaned": [], "total_bundled": 0,
"copied": [],
"updated": [],
"skipped": 0,
"user_modified": [],
"cleaned": [],
"total_bundled": 0,
}
SKILLS_DIR.mkdir(parents=True, exist_ok=True)
@ -187,7 +190,7 @@ def sync_skills(quiet: bool = False) -> dict:
manifest[skill_name] = bundled_hash
if not quiet:
print(f" + {skill_name}")
except (OSError, IOError) as e:
except OSError as e:
if not quiet:
print(f" ! Failed to copy {skill_name}: {e}")
# Do NOT add to manifest — next sync should retry
@ -229,12 +232,12 @@ def sync_skills(quiet: bool = False) -> dict:
print(f"{skill_name} (updated)")
# Remove backup after successful copy
shutil.rmtree(backup, ignore_errors=True)
except (OSError, IOError):
except OSError:
# Restore from backup
if backup.exists() and not dest.exists():
shutil.move(str(backup), str(dest))
raise
except (OSError, IOError) as e:
except OSError as e:
if not quiet:
print(f" ! Failed to update {skill_name}: {e}")
else:
@ -257,7 +260,7 @@ def sync_skills(quiet: bool = False) -> dict:
try:
dest_desc.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(desc_md, dest_desc)
except (OSError, IOError) as e:
except OSError as e:
logger.debug("Could not copy %s: %s", desc_md, e)
_write_manifest(manifest)

View file

@ -40,9 +40,9 @@ SKILL.md Format (YAML Frontmatter, agentskills.io compatible):
tags: [fine-tuning, llm]
related_skills: [peft, lora]
---
# Skill Title
Full instructions and content here...
Available tools:
@ -51,13 +51,13 @@ Available tools:
Usage:
from tools.skills_tool import skills_list, skill_view, check_skills_requirements
# List all skills (returns metadata only - token efficient)
result = skills_list()
# View a skill's main content (loads full instructions)
content = skill_view("axolotl")
# View a reference file within a skill (loads linked file)
content = skill_view("axolotl", "references/dataset-formats.md")
"""
@ -67,11 +67,10 @@ import os
import re
import sys
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple
from typing import Any
import yaml
# All skills live in ~/.hermes/skills/ (seeded from bundled skills/ on install).
# This is the single source of truth -- agent edits, hub installs, and bundled
# skills all coexist here without polluting the git repo.
@ -91,7 +90,7 @@ _PLATFORM_MAP = {
}
def skill_matches_platform(frontmatter: Dict[str, Any]) -> bool:
def skill_matches_platform(frontmatter: dict[str, Any]) -> bool:
"""Check if a skill is compatible with the current OS platform.
Skills declare platform requirements via a top-level ``platforms`` list
@ -123,28 +122,28 @@ def check_skills_requirements() -> bool:
return True
def _parse_frontmatter(content: str) -> Tuple[Dict[str, Any], str]:
def _parse_frontmatter(content: str) -> tuple[dict[str, Any], str]:
"""
Parse YAML frontmatter from markdown content.
Uses yaml.safe_load for full YAML support (nested metadata, lists, etc.)
with a fallback to simple key:value splitting for robustness.
Args:
content: Full markdown file content
Returns:
Tuple of (frontmatter dict, remaining content)
"""
frontmatter = {}
body = content
if content.startswith("---"):
end_match = re.search(r'\n---\s*\n', content[3:])
end_match = re.search(r"\n---\s*\n", content[3:])
if end_match:
yaml_content = content[3:end_match.start() + 3]
body = content[end_match.end() + 3:]
yaml_content = content[3 : end_match.start() + 3]
body = content[end_match.end() + 3 :]
try:
parsed = yaml.safe_load(yaml_content)
if isinstance(parsed, dict):
@ -152,18 +151,18 @@ def _parse_frontmatter(content: str) -> Tuple[Dict[str, Any], str]:
# yaml.safe_load returns None for empty frontmatter
except yaml.YAMLError:
# Fallback: simple key:value parsing for malformed YAML
for line in yaml_content.strip().split('\n'):
if ':' in line:
key, value = line.split(':', 1)
for line in yaml_content.strip().split("\n"):
if ":" in line:
key, value = line.split(":", 1)
frontmatter[key.strip()] = value.strip()
return frontmatter, body
def _get_category_from_path(skill_path: Path) -> Optional[str]:
def _get_category_from_path(skill_path: Path) -> str | None:
"""
Extract category from skill path based on directory structure.
For paths like: ~/.hermes/skills/mlops/axolotl/SKILL.md -> "mlops"
"""
try:
@ -179,134 +178,136 @@ def _get_category_from_path(skill_path: Path) -> Optional[str]:
def _estimate_tokens(content: str) -> int:
"""
Rough token estimate (4 chars per token average).
Args:
content: Text content
Returns:
Estimated token count
"""
return len(content) // 4
def _parse_tags(tags_value) -> List[str]:
def _parse_tags(tags_value) -> list[str]:
"""
Parse tags from frontmatter value.
Handles:
- Already-parsed list (from yaml.safe_load): [tag1, tag2]
- String with brackets: "[tag1, tag2]"
- Comma-separated string: "tag1, tag2"
Args:
tags_value: Raw tags value may be a list or string
Returns:
List of tag strings
"""
if not tags_value:
return []
# yaml.safe_load already returns a list for [tag1, tag2]
if isinstance(tags_value, list):
return [str(t).strip() for t in tags_value if t]
# String fallback — handle bracket-wrapped or comma-separated
tags_value = str(tags_value).strip()
if tags_value.startswith('[') and tags_value.endswith(']'):
if tags_value.startswith("[") and tags_value.endswith("]"):
tags_value = tags_value[1:-1]
return [t.strip().strip('"\'') for t in tags_value.split(',') if t.strip()]
return [t.strip().strip("\"'") for t in tags_value.split(",") if t.strip()]
def _find_all_skills() -> List[Dict[str, Any]]:
def _find_all_skills() -> list[dict[str, Any]]:
"""
Recursively find all skills in ~/.hermes/skills/.
Returns metadata for progressive disclosure (tier 1):
- name, description, category
Returns:
List of skill metadata dicts
"""
skills = []
if not SKILLS_DIR.exists():
return skills
for skill_md in SKILLS_DIR.rglob("SKILL.md"):
if any(part in ('.git', '.github', '.hub') for part in skill_md.parts):
if any(part in (".git", ".github", ".hub") for part in skill_md.parts):
continue
skill_dir = skill_md.parent
try:
content = skill_md.read_text(encoding='utf-8')
content = skill_md.read_text(encoding="utf-8")
frontmatter, body = _parse_frontmatter(content)
# Skip skills incompatible with the current OS platform
if not skill_matches_platform(frontmatter):
continue
name = frontmatter.get('name', skill_dir.name)[:MAX_NAME_LENGTH]
description = frontmatter.get('description', '')
name = frontmatter.get("name", skill_dir.name)[:MAX_NAME_LENGTH]
description = frontmatter.get("description", "")
if not description:
for line in body.strip().split('\n'):
for line in body.strip().split("\n"):
line = line.strip()
if line and not line.startswith('#'):
if line and not line.startswith("#"):
description = line
break
if len(description) > MAX_DESCRIPTION_LENGTH:
description = description[:MAX_DESCRIPTION_LENGTH - 3] + "..."
description = description[: MAX_DESCRIPTION_LENGTH - 3] + "..."
category = _get_category_from_path(skill_md)
skills.append({
"name": name,
"description": description,
"category": category,
})
skills.append(
{
"name": name,
"description": description,
"category": category,
}
)
except Exception:
continue
return skills
def _load_category_description(category_dir: Path) -> Optional[str]:
def _load_category_description(category_dir: Path) -> str | None:
"""
Load category description from DESCRIPTION.md if it exists.
Args:
category_dir: Path to the category directory
Returns:
Description string or None if not found
"""
desc_file = category_dir / "DESCRIPTION.md"
if not desc_file.exists():
return None
try:
content = desc_file.read_text(encoding='utf-8')
content = desc_file.read_text(encoding="utf-8")
# Parse frontmatter if present
frontmatter, body = _parse_frontmatter(content)
# Prefer frontmatter description, fall back to first non-header line
description = frontmatter.get('description', '')
description = frontmatter.get("description", "")
if not description:
for line in body.strip().split('\n'):
for line in body.strip().split("\n"):
line = line.strip()
if line and not line.startswith('#'):
if line and not line.startswith("#"):
description = line
break
# Truncate to reasonable length
if len(description) > MAX_DESCRIPTION_LENGTH:
description = description[:MAX_DESCRIPTION_LENGTH - 3] + "..."
description = description[: MAX_DESCRIPTION_LENGTH - 3] + "..."
return description if description else None
except Exception:
return None
@ -315,26 +316,24 @@ def _load_category_description(category_dir: Path) -> Optional[str]:
def skills_categories(verbose: bool = False, task_id: str = None) -> str:
"""
List available skill categories with descriptions (progressive disclosure tier 0).
Returns category names and descriptions for efficient discovery before drilling down.
Categories can have a DESCRIPTION.md file with a description frontmatter field
or first paragraph to explain what skills are in that category.
Args:
verbose: If True, include skill counts per category (default: False, but currently always included)
task_id: Optional task identifier (unused, for API consistency)
Returns:
JSON string with list of categories and their descriptions
"""
try:
if not SKILLS_DIR.exists():
return json.dumps({
"success": True,
"categories": [],
"message": "No skills directory found."
}, ensure_ascii=False)
return json.dumps(
{"success": True, "categories": [], "message": "No skills directory found."}, ensure_ascii=False
)
category_dirs = {}
for skill_md in SKILLS_DIR.rglob("SKILL.md"):
category = _get_category_from_path(skill_md)
@ -342,121 +341,125 @@ def skills_categories(verbose: bool = False, task_id: str = None) -> str:
category_dir = SKILLS_DIR / category
if category not in category_dirs:
category_dirs[category] = category_dir
categories = []
for name in sorted(category_dirs.keys()):
category_dir = category_dirs[name]
description = _load_category_description(category_dir)
skill_count = sum(1 for _ in category_dir.rglob("SKILL.md"))
cat_entry = {"name": name, "skill_count": skill_count}
if description:
cat_entry["description"] = description
categories.append(cat_entry)
return json.dumps({
"success": True,
"categories": categories,
"hint": "If a category is relevant to your task, use skills_list with that category to see available skills"
}, ensure_ascii=False)
return json.dumps(
{
"success": True,
"categories": categories,
"hint": "If a category is relevant to your task, use skills_list with that category to see available skills",
},
ensure_ascii=False,
)
except Exception as e:
return json.dumps({
"success": False,
"error": str(e)
}, ensure_ascii=False)
return json.dumps({"success": False, "error": str(e)}, ensure_ascii=False)
def skills_list(category: str = None, task_id: str = None) -> str:
"""
List all available skills (progressive disclosure tier 1 - minimal metadata).
Returns only name + description to minimize token usage. Use skill_view() to
Returns only name + description to minimize token usage. Use skill_view() to
load full content, tags, related files, etc.
Args:
category: Optional category filter (e.g., "mlops")
task_id: Optional task identifier (unused, for API consistency)
Returns:
JSON string with minimal skill info: name, description, category
"""
try:
if not SKILLS_DIR.exists():
SKILLS_DIR.mkdir(parents=True, exist_ok=True)
return json.dumps({
"success": True,
"skills": [],
"categories": [],
"message": "No skills found. Skills directory created at ~/.hermes/skills/"
}, ensure_ascii=False)
return json.dumps(
{
"success": True,
"skills": [],
"categories": [],
"message": "No skills found. Skills directory created at ~/.hermes/skills/",
},
ensure_ascii=False,
)
# Find all skills
all_skills = _find_all_skills()
if not all_skills:
return json.dumps({
"success": True,
"skills": [],
"categories": [],
"message": "No skills found in skills/ directory."
}, ensure_ascii=False)
return json.dumps(
{"success": True, "skills": [], "categories": [], "message": "No skills found in skills/ directory."},
ensure_ascii=False,
)
# Filter by category if specified
if category:
all_skills = [s for s in all_skills if s.get("category") == category]
# Sort by category then name
all_skills.sort(key=lambda s: (s.get("category") or "", s["name"]))
# Extract unique categories
categories = sorted(set(s.get("category") for s in all_skills if s.get("category")))
return json.dumps({
"success": True,
"skills": all_skills,
"categories": categories,
"count": len(all_skills),
"hint": "Use skill_view(name) to see full content, tags, and linked files"
}, ensure_ascii=False)
return json.dumps(
{
"success": True,
"skills": all_skills,
"categories": categories,
"count": len(all_skills),
"hint": "Use skill_view(name) to see full content, tags, and linked files",
},
ensure_ascii=False,
)
except Exception as e:
return json.dumps({
"success": False,
"error": str(e)
}, ensure_ascii=False)
return json.dumps({"success": False, "error": str(e)}, ensure_ascii=False)
def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
"""
View the content of a skill or a specific file within a skill directory.
Args:
name: Name or path of the skill (e.g., "axolotl" or "03-fine-tuning/axolotl")
file_path: Optional path to a specific file within the skill (e.g., "references/api.md")
task_id: Optional task identifier (unused, for API consistency)
Returns:
JSON string with skill content or error message
"""
try:
if not SKILLS_DIR.exists():
return json.dumps({
"success": False,
"error": "Skills directory does not exist yet. It will be created on first install."
}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"error": "Skills directory does not exist yet. It will be created on first install.",
},
ensure_ascii=False,
)
skill_dir = None
skill_md = None
# Try direct path first (e.g., "mlops/axolotl")
direct_path = SKILLS_DIR / name
if direct_path.is_dir() and (direct_path / "SKILL.md").exists():
skill_dir = direct_path
skill_md = direct_path / "SKILL.md"
elif direct_path.with_suffix('.md').exists():
skill_md = direct_path.with_suffix('.md')
elif direct_path.with_suffix(".md").exists():
skill_md = direct_path.with_suffix(".md")
# Search by directory name
if not skill_md:
for found_skill_md in SKILLS_DIR.rglob("SKILL.md"):
@ -464,64 +467,70 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
skill_dir = found_skill_md.parent
skill_md = found_skill_md
break
# Legacy: flat .md files
if not skill_md:
for found_md in SKILLS_DIR.rglob(f"{name}.md"):
if found_md.name != "SKILL.md":
skill_md = found_md
break
if not skill_md or not skill_md.exists():
# List available skills in error message
all_skills = _find_all_skills()
available = [s["name"] for s in all_skills[:20]] # Limit to 20
return json.dumps({
"success": False,
"error": f"Skill '{name}' not found.",
"available_skills": available,
"hint": "Use skills_list to see all available skills"
}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"error": f"Skill '{name}' not found.",
"available_skills": available,
"hint": "Use skills_list to see all available skills",
},
ensure_ascii=False,
)
# If a specific file path is requested, read that instead
if file_path and skill_dir:
# Security: Prevent path traversal attacks
normalized_path = Path(file_path)
if ".." in normalized_path.parts:
return json.dumps({
"success": False,
"error": "Path traversal ('..') is not allowed.",
"hint": "Use a relative path within the skill directory"
}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"error": "Path traversal ('..') is not allowed.",
"hint": "Use a relative path within the skill directory",
},
ensure_ascii=False,
)
target_file = skill_dir / file_path
# Security: Verify resolved path is still within skill directory
try:
resolved = target_file.resolve()
skill_dir_resolved = skill_dir.resolve()
if not resolved.is_relative_to(skill_dir_resolved):
return json.dumps({
"success": False,
"error": "Path escapes skill directory boundary.",
"hint": "Use a relative path within the skill directory"
}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"error": "Path escapes skill directory boundary.",
"hint": "Use a relative path within the skill directory",
},
ensure_ascii=False,
)
except (OSError, ValueError):
return json.dumps({
"success": False,
"error": f"Invalid file path: '{file_path}'",
"hint": "Use a valid relative path within the skill directory"
}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"error": f"Invalid file path: '{file_path}'",
"hint": "Use a valid relative path within the skill directory",
},
ensure_ascii=False,
)
if not target_file.exists():
# List available files in the skill directory, organized by type
available_files = {
"references": [],
"templates": [],
"assets": [],
"scripts": [],
"other": []
}
available_files = {"references": [], "templates": [], "assets": [], "scripts": [], "other": []}
# Scan for all readable files
for f in skill_dir.rglob("*"):
if f.is_file() and f.name != "SKILL.md":
@ -534,82 +543,85 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
available_files["assets"].append(rel)
elif rel.startswith("scripts/"):
available_files["scripts"].append(rel)
elif f.suffix in ['.md', '.py', '.yaml', '.yml', '.json', '.tex', '.sh']:
elif f.suffix in [".md", ".py", ".yaml", ".yml", ".json", ".tex", ".sh"]:
available_files["other"].append(rel)
# Remove empty categories
available_files = {k: v for k, v in available_files.items() if v}
return json.dumps({
"success": False,
"error": f"File '{file_path}' not found in skill '{name}'.",
"available_files": available_files,
"hint": "Use one of the available file paths listed above"
}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"error": f"File '{file_path}' not found in skill '{name}'.",
"available_files": available_files,
"hint": "Use one of the available file paths listed above",
},
ensure_ascii=False,
)
# Read the file content
try:
content = target_file.read_text(encoding='utf-8')
content = target_file.read_text(encoding="utf-8")
except UnicodeDecodeError:
# Binary file - return info about it instead
return json.dumps({
"success": True,
"name": name,
"file": file_path,
"content": f"[Binary file: {target_file.name}, size: {target_file.stat().st_size} bytes]",
"is_binary": True
}, ensure_ascii=False)
return json.dumps({
"success": True,
"name": name,
"file": file_path,
"content": content,
"file_type": target_file.suffix
}, ensure_ascii=False)
return json.dumps(
{
"success": True,
"name": name,
"file": file_path,
"content": f"[Binary file: {target_file.name}, size: {target_file.stat().st_size} bytes]",
"is_binary": True,
},
ensure_ascii=False,
)
return json.dumps(
{"success": True, "name": name, "file": file_path, "content": content, "file_type": target_file.suffix},
ensure_ascii=False,
)
# Read the main skill content
content = skill_md.read_text(encoding='utf-8')
content = skill_md.read_text(encoding="utf-8")
frontmatter, body = _parse_frontmatter(content)
# Get reference, template, asset, and script files if this is a directory-based skill
reference_files = []
template_files = []
asset_files = []
script_files = []
if skill_dir:
references_dir = skill_dir / "references"
if references_dir.exists():
reference_files = [str(f.relative_to(skill_dir)) for f in references_dir.glob("*.md")]
templates_dir = skill_dir / "templates"
if templates_dir.exists():
for ext in ['*.md', '*.py', '*.yaml', '*.yml', '*.json', '*.tex', '*.sh']:
for ext in ["*.md", "*.py", "*.yaml", "*.yml", "*.json", "*.tex", "*.sh"]:
template_files.extend([str(f.relative_to(skill_dir)) for f in templates_dir.rglob(ext)])
# assets/ — agentskills.io standard directory for supplementary files
assets_dir = skill_dir / "assets"
if assets_dir.exists():
for f in assets_dir.rglob("*"):
if f.is_file():
asset_files.append(str(f.relative_to(skill_dir)))
scripts_dir = skill_dir / "scripts"
if scripts_dir.exists():
for ext in ['*.py', '*.sh', '*.bash', '*.js', '*.ts', '*.rb']:
for ext in ["*.py", "*.sh", "*.bash", "*.js", "*.ts", "*.rb"]:
script_files.extend([str(f.relative_to(skill_dir)) for f in scripts_dir.glob(ext)])
# Read tags/related_skills with backward compat:
# Check metadata.hermes.* first (agentskills.io convention), fall back to top-level
hermes_meta = {}
metadata = frontmatter.get('metadata')
metadata = frontmatter.get("metadata")
if isinstance(metadata, dict):
hermes_meta = metadata.get('hermes', {}) or {}
tags = _parse_tags(hermes_meta.get('tags') or frontmatter.get('tags', ''))
related_skills = _parse_tags(hermes_meta.get('related_skills') or frontmatter.get('related_skills', ''))
hermes_meta = metadata.get("hermes", {}) or {}
tags = _parse_tags(hermes_meta.get("tags") or frontmatter.get("tags", ""))
related_skills = _parse_tags(hermes_meta.get("related_skills") or frontmatter.get("related_skills", ""))
# Build linked files structure for clear discovery
linked_files = {}
if reference_files:
@ -620,34 +632,33 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
linked_files["assets"] = asset_files
if script_files:
linked_files["scripts"] = script_files
rel_path = str(skill_md.relative_to(SKILLS_DIR))
result = {
"success": True,
"name": frontmatter.get('name', skill_md.stem if not skill_dir else skill_dir.name),
"description": frontmatter.get('description', ''),
"name": frontmatter.get("name", skill_md.stem if not skill_dir else skill_dir.name),
"description": frontmatter.get("description", ""),
"tags": tags,
"related_skills": related_skills,
"content": content,
"path": rel_path,
"linked_files": linked_files if linked_files else None,
"usage_hint": "To view linked files, call skill_view(name, file_path) where file_path is e.g. 'references/api.md' or 'assets/config.yaml'" if linked_files else None
"usage_hint": "To view linked files, call skill_view(name, file_path) where file_path is e.g. 'references/api.md' or 'assets/config.yaml'"
if linked_files
else None,
}
# Surface agentskills.io optional fields when present
if frontmatter.get('compatibility'):
result["compatibility"] = frontmatter['compatibility']
if frontmatter.get("compatibility"):
result["compatibility"] = frontmatter["compatibility"]
if isinstance(metadata, dict):
result["metadata"] = metadata
return json.dumps(result, ensure_ascii=False)
except Exception as e:
return json.dumps({
"success": False,
"error": str(e)
}, ensure_ascii=False)
return json.dumps({"success": False, "error": str(e)}, ensure_ascii=False)
# Tool description for model_tools.py
@ -669,7 +680,7 @@ if __name__ == "__main__":
"""Test the skills tool"""
print("🎯 Skills Tool Test")
print("=" * 60)
# Test listing skills
print("\n📋 Listing all skills:")
result = json.loads(skills_list())
@ -678,12 +689,12 @@ if __name__ == "__main__":
print(f"Categories: {result.get('categories', [])}")
print("\nFirst 10 skills:")
for skill in result["skills"][:10]:
cat = f"[{skill['category']}] " if skill.get('category') else ""
refs = f" (+{len(skill['reference_files'])} refs)" if skill.get('reference_files') else ""
cat = f"[{skill['category']}] " if skill.get("category") else ""
refs = f" (+{len(skill['reference_files'])} refs)" if skill.get("reference_files") else ""
print(f"{cat}{skill['name']}: {skill['description'][:60]}...{refs}")
else:
print(f"Error: {result['error']}")
# Test viewing a skill
print("\n📖 Viewing skill 'axolotl':")
result = json.loads(skill_view("axolotl"))
@ -691,11 +702,11 @@ if __name__ == "__main__":
print(f"Name: {result['name']}")
print(f"Description: {result.get('description', 'N/A')[:100]}...")
print(f"Content length: {len(result['content'])} chars")
if result.get('reference_files'):
if result.get("reference_files"):
print(f"Reference files: {result['reference_files']}")
else:
print(f"Error: {result['error']}")
# Test viewing a reference file
print("\n📄 Viewing reference file 'axolotl/references/dataset-formats.md':")
result = json.loads(skill_view("axolotl", "references/dataset-formats.md"))
@ -717,14 +728,9 @@ SKILLS_LIST_SCHEMA = {
"description": "List available skills (name + description). Use skill_view(name) to load full content.",
"parameters": {
"type": "object",
"properties": {
"category": {
"type": "string",
"description": "Optional category filter to narrow results"
}
},
"required": []
}
"properties": {"category": {"type": "string", "description": "Optional category filter to narrow results"}},
"required": [],
},
}
SKILL_VIEW_SCHEMA = {
@ -733,17 +739,14 @@ SKILL_VIEW_SCHEMA = {
"parameters": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "The skill name (use skills_list to see available skills)"
},
"name": {"type": "string", "description": "The skill name (use skills_list to see available skills)"},
"file_path": {
"type": "string",
"description": "OPTIONAL: Path to a linked file within the skill (e.g., 'references/api.md', 'templates/config.yaml', 'scripts/validate.py'). Omit to get the main SKILL.md content."
}
"description": "OPTIONAL: Path to a linked file within the skill (e.g., 'references/api.md', 'templates/config.yaml', 'scripts/validate.py'). Omit to get the main SKILL.md content.",
},
},
"required": ["name"]
}
"required": ["name"],
},
}
registry.register(

View file

@ -26,20 +26,22 @@ Usage:
result = terminal_tool("python server.py", background=True)
"""
import atexit
import json
import logging
import os
import signal
import sys
import time
import threading
import atexit
import shutil
import subprocess
import tempfile
import uuid
import sys
import threading
import time
from pathlib import Path
from typing import Optional, Dict, Any
from typing import Any
from tools.interrupt import (
_interrupt_event, # noqa: F401 — re-exported to environments/local.py
is_interrupted, # noqa: F401 — re-exported
)
from tools.interrupt import set_interrupt as set_interrupt_event # noqa: F401 — re-exported
logger = logging.getLogger(__name__)
@ -49,7 +51,6 @@ logger = logging.getLogger(__name__)
# The terminal tool polls this during command execution so it can kill
# long-running subprocesses immediately instead of blocking until timeout.
# ---------------------------------------------------------------------------
from tools.interrupt import set_interrupt as set_interrupt_event, is_interrupted, _interrupt_event
# Add mini-swe-agent to path if not installed
@ -65,7 +66,6 @@ if mini_swe_path.exists():
# Singularity helpers (scratch dir, SIF cache) now live in tools/environments/singularity.py
from tools.environments.singularity import _get_scratch_dir
# Disk usage warning threshold (in GB)
DISK_USAGE_WARNING_THRESHOLD_GB = float(os.getenv("TERMINAL_DISK_WARNING_GB", "500"))
@ -73,28 +73,32 @@ DISK_USAGE_WARNING_THRESHOLD_GB = float(os.getenv("TERMINAL_DISK_WARNING_GB", "5
def _check_disk_usage_warning():
"""Check if total disk usage exceeds warning threshold."""
scratch_dir = _get_scratch_dir()
try:
# Get total size of hermes directories
total_bytes = 0
import glob
for path in glob.glob(str(scratch_dir / "hermes-*")):
for f in Path(path).rglob('*'):
for f in Path(path).rglob("*"):
if f.is_file():
try:
total_bytes += f.stat().st_size
except OSError:
pass
total_gb = total_bytes / (1024 ** 3)
total_gb = total_bytes / (1024**3)
if total_gb > DISK_USAGE_WARNING_THRESHOLD_GB:
logger.warning("Disk usage (%.1fGB) exceeds threshold (%.0fGB). Consider running cleanup_all_environments().",
total_gb, DISK_USAGE_WARNING_THRESHOLD_GB)
logger.warning(
"Disk usage (%.1fGB) exceeds threshold (%.0fGB). Consider running cleanup_all_environments().",
total_gb,
DISK_USAGE_WARNING_THRESHOLD_GB,
)
return True
return False
except Exception as e:
except Exception:
return False
@ -121,59 +125,59 @@ def set_approval_callback(cb):
global _approval_callback
_approval_callback = cb
# =============================================================================
# Dangerous Command Approval System
# =============================================================================
# Dangerous command detection + approval now consolidated in tools/approval.py
from tools.approval import (
detect_dangerous_command as _detect_dangerous_command,
check_dangerous_command as _check_dangerous_command_impl,
load_permanent_allowlist as _load_permanent_allowlist,
DANGEROUS_PATTERNS,
)
def _check_dangerous_command(command: str, env_type: str) -> dict:
"""Delegate to the consolidated approval module, passing the CLI callback."""
return _check_dangerous_command_impl(command, env_type,
approval_callback=_approval_callback)
return _check_dangerous_command_impl(command, env_type, approval_callback=_approval_callback)
def _handle_sudo_failure(output: str, env_type: str) -> str:
"""
Check for sudo failure and add helpful message for messaging contexts.
Returns enhanced output if sudo failed in messaging context, else original.
"""
is_gateway = os.getenv("HERMES_GATEWAY_SESSION")
if not is_gateway:
return output
# Check for sudo failure indicators
sudo_failures = [
"sudo: a password is required",
"sudo: no tty present",
"sudo: a terminal is required",
]
for failure in sudo_failures:
if failure in output:
return output + "\n\n💡 Tip: To enable sudo over messaging, add SUDO_PASSWORD to ~/.hermes/.env on the agent machine."
return (
output
+ "\n\n💡 Tip: To enable sudo over messaging, add SUDO_PASSWORD to ~/.hermes/.env on the agent machine."
)
return output
def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
"""
Prompt user for sudo password with timeout.
Returns the password if entered, or empty string if:
- User presses Enter without input (skip)
- Timeout expires (45s default)
- Any error occurs
Only works in interactive mode (HERMES_INTERACTIVE=1).
If a _sudo_password_callback is registered (by the CLI), delegates to it
so the prompt integrates with prompt_toolkit's UI. Otherwise reads
@ -181,7 +185,7 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
"""
import sys
import time as time_module
# Use the registered callback when available (prompt_toolkit-compatible)
if _sudo_password_callback is not None:
try:
@ -190,13 +194,14 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
return ""
result = {"password": None, "done": False}
def read_password_thread():
"""Read password from /dev/tty with echo disabled."""
tty_fd = None
old_attrs = None
try:
import termios
tty_fd = os.open("/dev/tty", os.O_RDONLY)
old_attrs = termios.tcgetattr(tty_fd)
new_attrs = termios.tcgetattr(tty_fd)
@ -217,6 +222,7 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
if tty_fd is not None and old_attrs is not None:
try:
import termios as _termios
_termios.tcsetattr(tty_fd, _termios.TCSAFLUSH, old_attrs)
except Exception:
pass
@ -226,11 +232,11 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
except Exception:
pass
result["done"] = True
try:
os.environ["HERMES_SPINNER_PAUSE"] = "1"
time_module.sleep(0.2)
print()
print("" + "" * 58 + "")
print("│ 🔐 SUDO PASSWORD REQUIRED" + " " * 30 + "")
@ -241,11 +247,11 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
print("" + "" * 58 + "")
print()
print(" Password (hidden): ", end="", flush=True)
password_thread = threading.Thread(target=read_password_thread, daemon=True)
password_thread.start()
password_thread.join(timeout=timeout_seconds)
if result["done"]:
password = result["password"] or ""
print() # newline after hidden input
@ -262,7 +268,7 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
print()
sys.stdout.flush()
return ""
except (EOFError, KeyboardInterrupt):
print()
print(" ⏭ Cancelled - continuing without sudo")
@ -281,29 +287,29 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
def _transform_sudo_command(command: str) -> str:
"""
Transform sudo commands to use -S flag if SUDO_PASSWORD is available.
This is a shared helper used by all execution environments to provide
consistent sudo handling across local, SSH, and container environments.
If SUDO_PASSWORD is set (via env, config, or interactive prompt):
'sudo apt install curl' -> password piped via sudo -S
If SUDO_PASSWORD is not set and in interactive mode (HERMES_INTERACTIVE=1):
Prompts user for password with 45s timeout, caches for session.
If SUDO_PASSWORD is not set and NOT interactive:
Command runs as-is (fails gracefully with "sudo: a password is required").
"""
global _cached_sudo_password
import re
# Check if command even contains sudo
if not re.search(r'\bsudo\b', command):
if not re.search(r"\bsudo\b", command):
return command # No sudo in command, return as-is
# Try to get password from: env var -> session cache -> interactive prompt
sudo_password = os.getenv("SUDO_PASSWORD", "") or _cached_sudo_password
if not sudo_password:
# No password configured - check if we're in interactive mode
if os.getenv("HERMES_INTERACTIVE"):
@ -311,30 +317,30 @@ def _transform_sudo_command(command: str) -> str:
sudo_password = _prompt_for_sudo_password(timeout_seconds=45)
if sudo_password:
_cached_sudo_password = sudo_password # Cache for session
if not sudo_password:
return command # No password, let it fail gracefully
def replace_sudo(match):
# Replace 'sudo' with password-piped version
# The -S flag makes sudo read password from stdin
# The -p '' suppresses the password prompt
# Use shlex.quote() to prevent shell injection via password content
import shlex
return f"echo {shlex.quote(sudo_password)} | sudo -S -p ''"
# Match 'sudo' at word boundaries (not 'visudo' or 'sudoers')
# This handles: sudo, sudo -flag, etc.
return re.sub(r'\bsudo\b', replace_sudo, command)
return re.sub(r"\bsudo\b", replace_sudo, command)
# Environment classes now live in tools/environments/
from tools.environments.docker import DockerEnvironment as _DockerEnvironment
from tools.environments.local import LocalEnvironment as _LocalEnvironment
from tools.environments.modal import ModalEnvironment as _ModalEnvironment
from tools.environments.singularity import SingularityEnvironment as _SingularityEnvironment
from tools.environments.ssh import SSHEnvironment as _SSHEnvironment
from tools.environments.docker import DockerEnvironment as _DockerEnvironment
from tools.environments.modal import ModalEnvironment as _ModalEnvironment
# Tool description for LLM
TERMINAL_TOOL_DESCRIPTION = """Execute shell commands on a Linux environment. Filesystem persists between calls.
@ -356,10 +362,10 @@ Do NOT use vim/nano/interactive tools without pty=true — they hang without a p
"""
# Global state for environment lifecycle management
_active_environments: Dict[str, Any] = {}
_last_activity: Dict[str, float] = {}
_active_environments: dict[str, Any] = {}
_last_activity: dict[str, float] = {}
_env_lock = threading.Lock()
_creation_locks: Dict[str, threading.Lock] = {} # Per-task locks for sandbox creation
_creation_locks: dict[str, threading.Lock] = {} # Per-task locks for sandbox creation
_creation_locks_lock = threading.Lock() # Protects _creation_locks dict itself
_cleanup_thread = None
_cleanup_running = False
@ -372,10 +378,10 @@ _cleanup_running = False
#
# This is never exposed to the model -- only infrastructure code calls it.
# Thread-safe because each task_id is unique per rollout.
_task_env_overrides: Dict[str, Dict[str, Any]] = {}
_task_env_overrides: dict[str, dict[str, Any]] = {}
def register_task_env_overrides(task_id: str, overrides: Dict[str, Any]):
def register_task_env_overrides(task_id: str, overrides: dict[str, Any]):
"""
Register environment overrides for a specific task/rollout.
@ -402,13 +408,14 @@ def clear_task_env_overrides(task_id: str):
"""
_task_env_overrides.pop(task_id, None)
# Configuration from environment variables
def _get_env_config() -> Dict[str, Any]:
def _get_env_config() -> dict[str, Any]:
"""Get terminal environment configuration from environment variables."""
# Default image with Python and Node.js for maximum compatibility
default_image = "nikolaik/python-nodejs:python3.11-nodejs20"
env_type = os.getenv("TERMINAL_ENV", "local")
# Default cwd: local uses the host's current directory, everything
# else starts in the user's home (~ resolves to whatever account
# is running inside the container/remote).
@ -416,7 +423,7 @@ def _get_env_config() -> Dict[str, Any]:
default_cwd = os.getcwd()
else:
default_cwd = "~"
# Read TERMINAL_CWD but sanity-check it for container backends.
# If the CWD looks like a host-local path that can't exist inside a
# container/sandbox, fall back to the backend's own default. This
@ -426,9 +433,12 @@ def _get_env_config() -> Dict[str, Any]:
if env_type in ("modal", "docker", "singularity", "daytona") and cwd:
host_prefixes = ("/Users/", "C:\\", "C:/")
if any(cwd.startswith(p) for p in host_prefixes) and cwd != default_cwd:
logger.info("Ignoring TERMINAL_CWD=%r for %s backend "
"(host path won't exist in sandbox). Using %r instead.",
cwd, env_type, default_cwd)
logger.info(
"Ignoring TERMINAL_CWD=%r for %s backend (host path won't exist in sandbox). Using %r instead.",
cwd,
env_type,
default_cwd,
)
cwd = default_cwd
return {
@ -447,19 +457,25 @@ def _get_env_config() -> Dict[str, Any]:
"ssh_key": os.getenv("TERMINAL_SSH_KEY", ""),
# Container resource config (applies to docker, singularity, modal, daytona -- ignored for local/ssh)
"container_cpu": float(os.getenv("TERMINAL_CONTAINER_CPU", "1")),
"container_memory": int(os.getenv("TERMINAL_CONTAINER_MEMORY", "5120")), # MB (default 5GB)
"container_disk": int(os.getenv("TERMINAL_CONTAINER_DISK", "51200")), # MB (default 50GB)
"container_memory": int(os.getenv("TERMINAL_CONTAINER_MEMORY", "5120")), # MB (default 5GB)
"container_disk": int(os.getenv("TERMINAL_CONTAINER_DISK", "51200")), # MB (default 50GB)
"container_persistent": os.getenv("TERMINAL_CONTAINER_PERSISTENT", "true").lower() in ("true", "1", "yes"),
"docker_volumes": json.loads(os.getenv("TERMINAL_DOCKER_VOLUMES", "[]")),
}
def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
ssh_config: dict = None, container_config: dict = None,
task_id: str = "default"):
def _create_environment(
env_type: str,
image: str,
cwd: str,
timeout: int,
ssh_config: dict = None,
container_config: dict = None,
task_id: str = "default",
):
"""
Create an execution environment from mini-swe-agent.
Args:
env_type: One of "local", "docker", "singularity", "modal", "daytona", "ssh"
image: Docker/Singularity/Modal image name (ignored for local/ssh)
@ -468,7 +484,7 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
ssh_config: SSH connection config (for env_type="ssh")
container_config: Resource config for container backends (cpu, memory, disk, persistent)
task_id: Task identifier for environment reuse and snapshot keying
Returns:
Environment instance with execute() method
"""
@ -481,22 +497,32 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
if env_type == "local":
return _LocalEnvironment(cwd=cwd, timeout=timeout)
elif env_type == "docker":
return _DockerEnvironment(
image=image, cwd=cwd, timeout=timeout,
cpu=cpu, memory=memory, disk=disk,
persistent_filesystem=persistent, task_id=task_id,
image=image,
cwd=cwd,
timeout=timeout,
cpu=cpu,
memory=memory,
disk=disk,
persistent_filesystem=persistent,
task_id=task_id,
volumes=volumes,
)
elif env_type == "singularity":
return _SingularityEnvironment(
image=image, cwd=cwd, timeout=timeout,
cpu=cpu, memory=memory, disk=disk,
persistent_filesystem=persistent, task_id=task_id,
image=image,
cwd=cwd,
timeout=timeout,
cpu=cpu,
memory=memory,
disk=disk,
persistent_filesystem=persistent,
task_id=task_id,
)
elif env_type == "modal":
sandbox_kwargs = {}
if cpu > 0:
@ -505,20 +531,29 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
sandbox_kwargs["memory"] = memory
if disk > 0:
sandbox_kwargs["ephemeral_disk"] = disk
return _ModalEnvironment(
image=image, cwd=cwd, timeout=timeout,
image=image,
cwd=cwd,
timeout=timeout,
modal_sandbox_kwargs=sandbox_kwargs,
persistent_filesystem=persistent, task_id=task_id,
persistent_filesystem=persistent,
task_id=task_id,
)
elif env_type == "daytona":
# Lazy import so daytona SDK is only required when backend is selected.
from tools.environments.daytona import DaytonaEnvironment as _DaytonaEnvironment
return _DaytonaEnvironment(
image=image, cwd=cwd, timeout=timeout,
cpu=int(cpu), memory=memory, disk=disk,
persistent_filesystem=persistent, task_id=task_id,
image=image,
cwd=cwd,
timeout=timeout,
cpu=int(cpu),
memory=memory,
disk=disk,
persistent_filesystem=persistent,
task_id=task_id,
)
elif env_type == "ssh":
@ -534,7 +569,9 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
)
else:
raise ValueError(f"Unknown environment type: {env_type}. Use 'local', 'docker', 'singularity', 'modal', 'daytona', or 'ssh'")
raise ValueError(
f"Unknown environment type: {env_type}. Use 'local', 'docker', 'singularity', 'modal', 'daytona', or 'ssh'"
)
def _cleanup_inactive_envs(lifetime_seconds: int = 300):
@ -547,6 +584,7 @@ def _cleanup_inactive_envs(lifetime_seconds: int = 300):
# background processes (their _last_activity gets refreshed to keep them alive).
try:
from tools.process_registry import process_registry
for task_id in list(_last_activity.keys()):
if process_registry.has_active_processes(task_id):
_last_activity[task_id] = current_time # Keep sandbox alive
@ -579,16 +617,17 @@ def _cleanup_inactive_envs(lifetime_seconds: int = 300):
# ShellFileOperations from referencing a dead sandbox)
try:
from tools.file_tools import clear_file_ops_cache
clear_file_ops_cache(task_id)
except ImportError:
pass
try:
if hasattr(env, 'cleanup'):
if hasattr(env, "cleanup"):
env.cleanup()
elif hasattr(env, 'stop'):
elif hasattr(env, "stop"):
env.stop()
elif hasattr(env, 'terminate'):
elif hasattr(env, "terminate"):
env.terminate()
logger.info("Cleaned up inactive environment for task: %s", task_id)
@ -640,27 +679,28 @@ def _stop_cleanup_thread():
pass
def get_active_environments_info() -> Dict[str, Any]:
def get_active_environments_info() -> dict[str, Any]:
"""Get information about currently active environments."""
info = {
"count": len(_active_environments),
"task_ids": list(_active_environments.keys()),
"workdirs": {},
}
# Calculate total disk usage (per-task to avoid double-counting)
total_size = 0
for task_id in _active_environments.keys():
for task_id in _active_environments:
scratch_dir = _get_scratch_dir()
pattern = f"hermes-*{task_id[:8]}*"
import glob
for path in glob.glob(str(scratch_dir / pattern)):
try:
size = sum(f.stat().st_size for f in Path(path).rglob('*') if f.is_file())
size = sum(f.stat().st_size for f in Path(path).rglob("*") if f.is_file())
total_size += size
except OSError:
pass
info["total_disk_usage_mb"] = round(total_size / (1024 * 1024), 2)
return info
@ -668,27 +708,28 @@ def get_active_environments_info() -> Dict[str, Any]:
def cleanup_all_environments():
"""Clean up ALL active environments. Use with caution."""
global _active_environments, _last_activity
task_ids = list(_active_environments.keys())
cleaned = 0
for task_id in task_ids:
try:
cleanup_vm(task_id)
cleaned += 1
except Exception as e:
logger.error("Error cleaning %s: %s", task_id, e, exc_info=True)
# Also clean any orphaned directories
scratch_dir = _get_scratch_dir()
import glob
for path in glob.glob(str(scratch_dir / "hermes-*")):
try:
shutil.rmtree(path, ignore_errors=True)
logger.info("Removed orphaned: %s", path)
except OSError:
pass
if cleaned > 0:
logger.info("Cleaned %d environments", cleaned)
return cleaned
@ -713,6 +754,7 @@ def cleanup_vm(task_id: str):
# Invalidate stale file_ops cache entry
try:
from tools.file_tools import clear_file_ops_cache
clear_file_ops_cache(task_id)
except ImportError:
pass
@ -721,11 +763,11 @@ def cleanup_vm(task_id: str):
return
try:
if hasattr(env, 'cleanup'):
if hasattr(env, "cleanup"):
env.cleanup()
elif hasattr(env, 'stop'):
elif hasattr(env, "stop"):
env.stop()
elif hasattr(env, 'terminate'):
elif hasattr(env, "terminate"):
env.terminate()
logger.info("Manually cleaned up environment for task: %s", task_id)
@ -746,17 +788,18 @@ def _atexit_cleanup():
logger.info("Shutting down %d remaining sandbox(es)...", count)
cleanup_all_environments()
atexit.register(_atexit_cleanup)
def terminal_tool(
command: str,
background: bool = False,
timeout: Optional[int] = None,
task_id: Optional[str] = None,
timeout: int | None = None,
task_id: str | None = None,
force: bool = False,
workdir: Optional[str] = None,
check_interval: Optional[int] = None,
workdir: str | None = None,
check_interval: int | None = None,
pty: bool = False,
) -> str:
"""
@ -784,7 +827,7 @@ def terminal_tool(
# With custom timeout
>>> result = terminal_tool(command="long_task.sh", timeout=300)
# Force run after user confirmation
# Note: force parameter is internal only, not exposed to model API
"""
@ -801,7 +844,7 @@ def terminal_tool(
# Check per-task overrides (set by environments like TerminalBench2Env)
# before falling back to global env var config
overrides = _task_env_overrides.get(effective_task_id, {})
# Select image based on env type, with per-task override support
if env_type == "docker":
image = overrides.get("docker_image") or config["docker_image"]
@ -882,12 +925,15 @@ def terminal_tool(
task_id=effective_task_id,
)
except ImportError as e:
return json.dumps({
"output": "",
"exit_code": -1,
"error": f"Terminal tool disabled: mini-swe-agent not available ({e})",
"status": "disabled"
}, ensure_ascii=False)
return json.dumps(
{
"output": "",
"exit_code": -1,
"error": f"Terminal tool disabled: mini-swe-agent not available ({e})",
"status": "disabled",
},
ensure_ascii=False,
)
with _env_lock:
_active_environments[effective_task_id] = new_env
@ -902,27 +948,33 @@ def terminal_tool(
if not approval["approved"]:
# Check if this is an approval_required (gateway ask mode)
if approval.get("status") == "approval_required":
return json.dumps({
"output": "",
"exit_code": -1,
"error": approval.get("message", "Waiting for user approval"),
"status": "approval_required",
"command": approval.get("command", command),
"description": approval.get("description", "dangerous command"),
"pattern_key": approval.get("pattern_key", ""),
}, ensure_ascii=False)
return json.dumps(
{
"output": "",
"exit_code": -1,
"error": approval.get("message", "Waiting for user approval"),
"status": "approval_required",
"command": approval.get("command", command),
"description": approval.get("description", "dangerous command"),
"pattern_key": approval.get("pattern_key", ""),
},
ensure_ascii=False,
)
# Command was blocked - include the pattern category so the caller knows why
desc = approval.get("description", "potentially dangerous operation")
fallback_msg = (
f"Command denied: matches '{desc}' pattern. "
"Use the approval prompt to allow it, or rephrase the command."
)
return json.dumps({
"output": "",
"exit_code": -1,
"error": approval.get("message", fallback_msg),
"status": "blocked"
}, ensure_ascii=False)
return json.dumps(
{
"output": "",
"exit_code": -1,
"error": approval.get("message", fallback_msg),
"status": "blocked",
},
ensure_ascii=False,
)
# Prepare command for execution
if background:
@ -940,7 +992,7 @@ def terminal_tool(
cwd=effective_cwd,
task_id=effective_task_id,
session_key=session_key,
env_vars=env.env if hasattr(env, 'env') else None,
env_vars=env.env if hasattr(env, "env") else None,
use_pty=pty,
)
else:
@ -964,38 +1016,36 @@ def terminal_tool(
max_timeout = effective_timeout
if timeout and timeout > max_timeout:
result_data["timeout_note"] = (
f"Requested timeout {timeout}s was clamped to "
f"configured limit of {max_timeout}s"
f"Requested timeout {timeout}s was clamped to configured limit of {max_timeout}s"
)
# Register check_interval watcher (gateway picks this up after agent run)
if check_interval and background:
effective_interval = max(30, check_interval)
if check_interval < 30:
result_data["check_interval_note"] = (
f"Requested {check_interval}s raised to minimum 30s"
)
process_registry.pending_watchers.append({
"session_id": proc_session.id,
"check_interval": effective_interval,
"session_key": session_key,
"platform": os.getenv("HERMES_SESSION_PLATFORM", ""),
"chat_id": os.getenv("HERMES_SESSION_CHAT_ID", ""),
})
result_data["check_interval_note"] = f"Requested {check_interval}s raised to minimum 30s"
process_registry.pending_watchers.append(
{
"session_id": proc_session.id,
"check_interval": effective_interval,
"session_key": session_key,
"platform": os.getenv("HERMES_SESSION_PLATFORM", ""),
"chat_id": os.getenv("HERMES_SESSION_CHAT_ID", ""),
}
)
return json.dumps(result_data, ensure_ascii=False)
except Exception as e:
return json.dumps({
"output": "",
"exit_code": -1,
"error": f"Failed to start background process: {str(e)}"
}, ensure_ascii=False)
return json.dumps(
{"output": "", "exit_code": -1, "error": f"Failed to start background process: {str(e)}"},
ensure_ascii=False,
)
else:
# Run foreground command with retry logic
max_retries = 3
retry_count = 0
result = None
while retry_count <= max_retries:
try:
execute_kwargs = {"timeout": effective_timeout}
@ -1005,39 +1055,61 @@ def terminal_tool(
except Exception as e:
error_str = str(e).lower()
if "timeout" in error_str:
return json.dumps({
"output": "",
"exit_code": 124,
"error": f"Command timed out after {effective_timeout} seconds"
}, ensure_ascii=False)
return json.dumps(
{
"output": "",
"exit_code": 124,
"error": f"Command timed out after {effective_timeout} seconds",
},
ensure_ascii=False,
)
# Retry on transient errors
if retry_count < max_retries:
retry_count += 1
wait_time = 2 ** retry_count
logger.warning("Execution error, retrying in %ds (attempt %d/%d) - Command: %s - Error: %s: %s - Task: %s, Backend: %s",
wait_time, retry_count, max_retries, command[:200], type(e).__name__, e, effective_task_id, env_type)
wait_time = 2**retry_count
logger.warning(
"Execution error, retrying in %ds (attempt %d/%d) - Command: %s - Error: %s: %s - Task: %s, Backend: %s",
wait_time,
retry_count,
max_retries,
command[:200],
type(e).__name__,
e,
effective_task_id,
env_type,
)
time.sleep(wait_time)
continue
logger.error("Execution failed after %d retries - Command: %s - Error: %s: %s - Task: %s, Backend: %s",
max_retries, command[:200], type(e).__name__, e, effective_task_id, env_type)
return json.dumps({
"output": "",
"exit_code": -1,
"error": f"Command execution failed: {type(e).__name__}: {str(e)}"
}, ensure_ascii=False)
logger.error(
"Execution failed after %d retries - Command: %s - Error: %s: %s - Task: %s, Backend: %s",
max_retries,
command[:200],
type(e).__name__,
e,
effective_task_id,
env_type,
)
return json.dumps(
{
"output": "",
"exit_code": -1,
"error": f"Command execution failed: {type(e).__name__}: {str(e)}",
},
ensure_ascii=False,
)
# Got a result
break
# Extract output
output = result.get("output", "")
returncode = result.get("returncode", 0)
# Add helpful message for sudo failures in messaging context
output = _handle_sudo_failure(output, env_type)
# Truncate output if too long, keeping both head and tail
MAX_OUTPUT_CHARS = 50000
if len(output) > MAX_OUTPUT_CHARS:
@ -1045,65 +1117,56 @@ def terminal_tool(
tail_chars = MAX_OUTPUT_CHARS - head_chars # 60% tail (most recent/relevant output)
omitted = len(output) - head_chars - tail_chars
truncated_notice = (
f"\n\n... [OUTPUT TRUNCATED - {omitted} chars omitted "
f"out of {len(output)} total] ...\n\n"
f"\n\n... [OUTPUT TRUNCATED - {omitted} chars omitted out of {len(output)} total] ...\n\n"
)
output = output[:head_chars] + truncated_notice + output[-tail_chars:]
# Redact secrets from command output (catches env/printenv leaking keys)
from agent.redact import redact_sensitive_text
output = redact_sensitive_text(output.strip()) if output else ""
return json.dumps({
"output": output,
"exit_code": returncode,
"error": None
}, ensure_ascii=False)
return json.dumps({"output": output, "exit_code": returncode, "error": None}, ensure_ascii=False)
except Exception as e:
return json.dumps({
"output": "",
"exit_code": -1,
"error": f"Failed to execute command: {str(e)}",
"status": "error"
}, ensure_ascii=False)
return json.dumps(
{"output": "", "exit_code": -1, "error": f"Failed to execute command: {str(e)}", "status": "error"},
ensure_ascii=False,
)
def check_terminal_requirements() -> bool:
"""Check if all requirements for the terminal tool are met."""
config = _get_env_config()
env_type = config["env_type"]
try:
if env_type == "local":
from minisweagent.environments.local import LocalEnvironment
return True
elif env_type == "docker":
from minisweagent.environments.docker import DockerEnvironment
# Check if docker is available
import subprocess
result = subprocess.run(["docker", "version"], capture_output=True, timeout=5)
return result.returncode == 0
elif env_type == "singularity":
from minisweagent.environments.singularity import SingularityEnvironment
import shutil
# Check if singularity/apptainer is available
import subprocess
import shutil
executable = shutil.which("apptainer") or shutil.which("singularity")
if executable:
result = subprocess.run([executable, "--version"], capture_output=True, timeout=5)
return result.returncode == 0
return False
elif env_type == "ssh":
from tools.environments.ssh import SSHEnvironment
# Check that host and user are configured
return bool(config.get("ssh_host")) and bool(config.get("ssh_user"))
elif env_type == "modal":
from minisweagent.environments.extra.swerex_modal import SwerexModalEnvironment
# Check for modal token
return os.getenv("MODAL_TOKEN_ID") is not None or Path.home().joinpath(".modal.toml").exists()
elif env_type == "daytona":
from daytona import Daytona
return os.getenv("DAYTONA_API_KEY") is not None
else:
return False
@ -1116,9 +1179,9 @@ if __name__ == "__main__":
# Simple test when run directly
print("Terminal Tool Module (mini-swe-agent backend)")
print("=" * 50)
config = _get_env_config()
print(f"\nCurrent Configuration:")
print("\nCurrent Configuration:")
print(f" Environment type: {config['env_type']}")
print(f" Docker image: {config['docker_image']}")
print(f" Modal image: {config['modal_image']}")
@ -1165,37 +1228,34 @@ TERMINAL_SCHEMA = {
"parameters": {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The command to execute on the VM"
},
"command": {"type": "string", "description": "The command to execute on the VM"},
"background": {
"type": "boolean",
"description": "ONLY for servers/watchers that never exit. For scripts, builds, installs — use foreground with timeout instead (it returns instantly when done).",
"default": False
"default": False,
},
"timeout": {
"type": "integer",
"description": "Max seconds to wait (default: 180). Returns INSTANTLY when command finishes — set high for long tasks, you won't wait unnecessarily.",
"minimum": 1
"minimum": 1,
},
"workdir": {
"type": "string",
"description": "Working directory for this command (absolute path). Defaults to the session working directory."
"description": "Working directory for this command (absolute path). Defaults to the session working directory.",
},
"check_interval": {
"type": "integer",
"description": "Seconds between automatic status checks for background processes (gateway/messaging only, minimum 30). When set, I'll proactively report progress.",
"minimum": 30
"minimum": 30,
},
"pty": {
"type": "boolean",
"description": "Run in pseudo-terminal (PTY) mode for interactive CLI tools like Codex, Claude Code, or Python REPL. Only works with local and SSH backends. Default: false.",
"default": False
}
"default": False,
},
},
"required": ["command"]
}
"required": ["command"],
},
}

View file

@ -15,8 +15,7 @@ Design:
"""
import json
from typing import Dict, Any, List, Optional
from typing import Any
# Valid status values for todo items
VALID_STATUSES = {"pending", "in_progress", "completed", "cancelled"}
@ -33,9 +32,9 @@ class TodoStore:
"""
def __init__(self):
self._items: List[Dict[str, str]] = []
self._items: list[dict[str, str]] = []
def write(self, todos: List[Dict[str, Any]], merge: bool = False) -> List[Dict[str, str]]:
def write(self, todos: list[dict[str, Any]], merge: bool = False) -> list[dict[str, str]]:
"""
Write todos. Returns the full current list after writing.
@ -79,7 +78,7 @@ class TodoStore:
self._items = rebuilt
return self.read()
def read(self) -> List[Dict[str, str]]:
def read(self) -> list[dict[str, str]]:
"""Return a copy of the current list."""
return [item.copy() for item in self._items]
@ -87,7 +86,7 @@ class TodoStore:
"""Check if there are any items in the list."""
return len(self._items) > 0
def format_for_injection(self) -> Optional[str]:
def format_for_injection(self) -> str | None:
"""
Render the todo list for post-compression injection.
@ -113,7 +112,7 @@ class TodoStore:
return "\n".join(lines)
@staticmethod
def _validate(item: Dict[str, Any]) -> Dict[str, str]:
def _validate(item: dict[str, Any]) -> dict[str, str]:
"""
Validate and normalize a todo item.
@ -136,9 +135,9 @@ class TodoStore:
def todo_tool(
todos: Optional[List[Dict[str, Any]]] = None,
todos: list[dict[str, Any]] | None = None,
merge: bool = False,
store: Optional[TodoStore] = None,
store: TodoStore | None = None,
) -> str:
"""
Single entry point for the todo tool. Reads or writes depending on params.
@ -165,16 +164,19 @@ def todo_tool(
completed = sum(1 for i in items if i["status"] == "completed")
cancelled = sum(1 for i in items if i["status"] == "cancelled")
return json.dumps({
"todos": items,
"summary": {
"total": len(items),
"pending": pending,
"in_progress": in_progress,
"completed": completed,
"cancelled": cancelled,
return json.dumps(
{
"todos": items,
"summary": {
"total": len(items),
"pending": pending,
"in_progress": in_progress,
"completed": completed,
"cancelled": cancelled,
},
},
}, ensure_ascii=False)
ensure_ascii=False,
)
def check_todo_requirements() -> bool:
@ -214,34 +216,27 @@ TODO_SCHEMA = {
"items": {
"type": "object",
"properties": {
"id": {
"type": "string",
"description": "Unique item identifier"
},
"content": {
"type": "string",
"description": "Task description"
},
"id": {"type": "string", "description": "Unique item identifier"},
"content": {"type": "string", "description": "Task description"},
"status": {
"type": "string",
"enum": ["pending", "in_progress", "completed", "cancelled"],
"description": "Current status"
}
"description": "Current status",
},
},
"required": ["id", "content", "status"]
}
"required": ["id", "content", "status"],
},
},
"merge": {
"type": "boolean",
"description": (
"true: update existing items by id, add new ones. "
"false (default): replace the entire list."
"true: update existing items by id, add new ones. false (default): replace the entire list."
),
"default": False
}
"default": False,
},
},
"required": []
}
"required": [],
},
}
@ -253,6 +248,7 @@ registry.register(
toolset="todo",
schema=TODO_SCHEMA,
handler=lambda args, **kw: todo_tool(
todos=args.get("todos"), merge=args.get("merge", False), store=kw.get("store")),
todos=args.get("todos"), merge=args.get("merge", False), store=kw.get("store")
),
check_fn=check_todo_requirements,
)

View file

@ -24,7 +24,7 @@ Usage:
import logging
import os
from pathlib import Path
from typing import Optional, Dict, Any
from typing import Any
logger = logging.getLogger(__name__)
@ -39,7 +39,7 @@ SUPPORTED_FORMATS = {".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm",
MAX_FILE_SIZE = 25 * 1024 * 1024
def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, Any]:
def transcribe_audio(file_path: str, model: str | None = None) -> dict[str, Any]:
"""
Transcribe an audio file using OpenAI's Whisper API.
@ -65,7 +65,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
}
audio_path = Path(file_path)
# Validate file exists
if not audio_path.exists():
return {
@ -73,14 +73,14 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
"transcript": "",
"error": f"Audio file not found: {file_path}",
}
if not audio_path.is_file():
return {
"success": False,
"transcript": "",
"error": f"Path is not a file: {file_path}",
}
# Validate file extension
if audio_path.suffix.lower() not in SUPPORTED_FORMATS:
return {
@ -88,7 +88,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
"transcript": "",
"error": f"Unsupported file format: {audio_path.suffix}. Supported formats: {', '.join(sorted(SUPPORTED_FORMATS))}",
}
# Validate file size
try:
file_size = audio_path.stat().st_size
@ -96,7 +96,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
return {
"success": False,
"transcript": "",
"error": f"File too large: {file_size / (1024*1024):.1f}MB (max {MAX_FILE_SIZE / (1024*1024)}MB)",
"error": f"File too large: {file_size / (1024 * 1024):.1f}MB (max {MAX_FILE_SIZE / (1024 * 1024)}MB)",
}
except OSError as e:
logger.error("Failed to get file size for %s: %s", file_path, e, exc_info=True)
@ -111,7 +111,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
model = DEFAULT_STT_MODEL
try:
from openai import OpenAI, APIError, APIConnectionError, APITimeoutError
from openai import APIConnectionError, APIError, APITimeoutError, OpenAI
client = OpenAI(api_key=api_key, base_url="https://api.openai.com/v1")

View file

@ -27,9 +27,8 @@ import logging
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Dict, Any, Optional
from typing import Any
logger = logging.getLogger(__name__)
@ -38,12 +37,14 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
try:
import edge_tts
_HAS_EDGE_TTS = True
except ImportError:
_HAS_EDGE_TTS = False
try:
from elevenlabs.client import ElevenLabs
_HAS_ELEVENLABS = True
except ImportError:
_HAS_ELEVENLABS = False
@ -51,6 +52,7 @@ except ImportError:
# openai is a core dependency, but guard anyway
try:
from openai import OpenAI as OpenAIClient
_HAS_OPENAI = True
except ImportError:
_HAS_OPENAI = False
@ -72,7 +74,7 @@ MAX_TEXT_LENGTH = 4000
# ===========================================================================
# Config loader -- reads tts: section from ~/.hermes/config.yaml
# ===========================================================================
def _load_tts_config() -> Dict[str, Any]:
def _load_tts_config() -> dict[str, Any]:
"""
Load TTS configuration from ~/.hermes/config.yaml.
@ -81,13 +83,14 @@ def _load_tts_config() -> Dict[str, Any]:
"""
try:
from hermes_cli.config import load_config
config = load_config()
return config.get("tts", {})
except Exception:
return {}
def _get_provider(tts_config: Dict[str, Any]) -> str:
def _get_provider(tts_config: dict[str, Any]) -> str:
"""Get the configured TTS provider name."""
return tts_config.get("provider", DEFAULT_PROVIDER).lower().strip()
@ -100,7 +103,7 @@ def _has_ffmpeg() -> bool:
return shutil.which("ffmpeg") is not None
def _convert_to_opus(mp3_path: str) -> Optional[str]:
def _convert_to_opus(mp3_path: str) -> str | None:
"""
Convert an MP3 file to OGG Opus format for Telegram voice bubbles.
@ -116,9 +119,9 @@ def _convert_to_opus(mp3_path: str) -> Optional[str]:
ogg_path = mp3_path.rsplit(".", 1)[0] + ".ogg"
try:
subprocess.run(
["ffmpeg", "-i", mp3_path, "-acodec", "libopus",
"-ac", "1", "-b:a", "64k", "-vbr", "off", ogg_path, "-y"],
capture_output=True, timeout=30,
["ffmpeg", "-i", mp3_path, "-acodec", "libopus", "-ac", "1", "-b:a", "64k", "-vbr", "off", ogg_path, "-y"],
capture_output=True,
timeout=30,
)
if os.path.exists(ogg_path) and os.path.getsize(ogg_path) > 0:
return ogg_path
@ -130,7 +133,7 @@ def _convert_to_opus(mp3_path: str) -> Optional[str]:
# ===========================================================================
# Provider: Edge TTS (free)
# ===========================================================================
async def _generate_edge_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -> str:
async def _generate_edge_tts(text: str, output_path: str, tts_config: dict[str, Any]) -> str:
"""
Generate audio using Edge TTS.
@ -153,7 +156,7 @@ async def _generate_edge_tts(text: str, output_path: str, tts_config: Dict[str,
# ===========================================================================
# Provider: ElevenLabs (premium)
# ===========================================================================
def _generate_elevenlabs(text: str, output_path: str, tts_config: Dict[str, Any]) -> str:
def _generate_elevenlabs(text: str, output_path: str, tts_config: dict[str, Any]) -> str:
"""
Generate audio using ElevenLabs.
@ -198,7 +201,7 @@ def _generate_elevenlabs(text: str, output_path: str, tts_config: Dict[str, Any]
# ===========================================================================
# Provider: OpenAI TTS
# ===========================================================================
def _generate_openai_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -> str:
def _generate_openai_tts(text: str, output_path: str, tts_config: dict[str, Any]) -> str:
"""
Generate audio using OpenAI TTS.
@ -241,7 +244,7 @@ def _generate_openai_tts(text: str, output_path: str, tts_config: Dict[str, Any]
# ===========================================================================
def text_to_speech_tool(
text: str,
output_path: Optional[str] = None,
output_path: str | None = None,
) -> str:
"""
Convert text to speech audio.
@ -276,7 +279,7 @@ def text_to_speech_tool(
# produce Opus natively (no ffmpeg needed). Edge TTS always outputs MP3
# and needs ffmpeg for conversion.
platform = os.getenv("HERMES_SESSION_PLATFORM", "").lower()
want_opus = (platform == "telegram")
want_opus = platform == "telegram"
# Determine output path
if output_path:
@ -300,47 +303,48 @@ def text_to_speech_tool(
# Generate audio with the configured provider
if provider == "elevenlabs":
if not _HAS_ELEVENLABS:
return json.dumps({
"success": False,
"error": "ElevenLabs provider selected but 'elevenlabs' package not installed. Run: pip install elevenlabs"
}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"error": "ElevenLabs provider selected but 'elevenlabs' package not installed. Run: pip install elevenlabs",
},
ensure_ascii=False,
)
logger.info("Generating speech with ElevenLabs...")
_generate_elevenlabs(text, file_str, tts_config)
elif provider == "openai":
if not _HAS_OPENAI:
return json.dumps({
"success": False,
"error": "OpenAI provider selected but 'openai' package not installed."
}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "OpenAI provider selected but 'openai' package not installed."},
ensure_ascii=False,
)
logger.info("Generating speech with OpenAI TTS...")
_generate_openai_tts(text, file_str, tts_config)
else:
# Default: Edge TTS (free)
if not _HAS_EDGE_TTS:
return json.dumps({
"success": False,
"error": "Edge TTS not available. Run: pip install edge-tts"
}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "Edge TTS not available. Run: pip install edge-tts"}, ensure_ascii=False
)
logger.info("Generating speech with Edge TTS...")
# Edge TTS is async, run it
try:
loop = asyncio.get_running_loop()
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
pool.submit(
lambda: asyncio.run(_generate_edge_tts(text, file_str, tts_config))
).result(timeout=60)
pool.submit(lambda: asyncio.run(_generate_edge_tts(text, file_str, tts_config))).result(timeout=60)
except RuntimeError:
asyncio.run(_generate_edge_tts(text, file_str, tts_config))
# Check the file was actually created
if not os.path.exists(file_str) or os.path.getsize(file_str) == 0:
return json.dumps({
"success": False,
"error": f"TTS generation produced no output (provider: {provider})"
}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": f"TTS generation produced no output (provider: {provider})"},
ensure_ascii=False,
)
# Try Opus conversion for Telegram compatibility (Edge TTS only outputs MP3)
voice_compatible = False
@ -361,13 +365,16 @@ def text_to_speech_tool(
if voice_compatible:
media_tag = f"[[audio_as_voice]]\n{media_tag}"
return json.dumps({
"success": True,
"file_path": file_str,
"media_tag": media_tag,
"provider": provider,
"voice_compatible": voice_compatible,
}, ensure_ascii=False)
return json.dumps(
{
"success": True,
"file_path": file_str,
"media_tag": media_tag,
"provider": provider,
"voice_compatible": voice_compatible,
},
ensure_ascii=False,
)
except Exception as e:
error_msg = f"TTS generation failed ({provider}): {e}"
@ -404,7 +411,7 @@ if __name__ == "__main__":
print("🔊 Text-to-Speech Tool Module")
print("=" * 50)
print(f"\nProvider availability:")
print("\nProvider availability:")
print(f" Edge TTS: {'✅ installed' if _HAS_EDGE_TTS else '❌ not installed (pip install edge-tts)'}")
print(f" ElevenLabs: {'✅ installed' if _HAS_ELEVENLABS else '❌ not installed (pip install elevenlabs)'}")
print(f" API Key: {'✅ set' if os.getenv('ELEVENLABS_API_KEY') else '❌ not set'}")
@ -429,25 +436,20 @@ TTS_SCHEMA = {
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "The text to convert to speech. Keep under 4000 characters."
},
"text": {"type": "string", "description": "The text to convert to speech. Keep under 4000 characters."},
"output_path": {
"type": "string",
"description": "Optional custom file path to save the audio. Defaults to ~/.hermes/audio_cache/<timestamp>.mp3"
}
"description": "Optional custom file path to save the audio. Defaults to ~/.hermes/audio_cache/<timestamp>.mp3",
},
},
"required": ["text"]
}
"required": ["text"],
},
}
registry.register(
name="text_to_speech",
toolset="tts",
schema=TTS_SCHEMA,
handler=lambda args, **kw: text_to_speech_tool(
text=args.get("text", ""),
output_path=args.get("output_path")),
handler=lambda args, **kw: text_to_speech_tool(text=args.get("text", ""), output_path=args.get("output_path")),
check_fn=check_tts_requirements,
)

View file

@ -19,7 +19,7 @@ Features:
Usage:
from vision_tools import vision_analyze_tool
import asyncio
# Analyze an image
result = await vision_analyze_tool(
image_url="https://example.com/image.jpg",
@ -33,11 +33,14 @@ import json
import logging
import os
import uuid
from collections.abc import Awaitable
from pathlib import Path
from typing import Any, Awaitable, Dict, Optional
from typing import Any
from urllib.parse import urlparse
import httpx
from openai import AsyncOpenAI
from agent.auxiliary_client import get_vision_auxiliary_client
from tools.debug_helpers import DebugSession
@ -55,7 +58,7 @@ if _aux_sync_client is not None:
_async_kwargs["default_headers"] = {
"HTTP-Referer": "https://github.com/NousResearch/hermes-agent",
"X-OpenRouter-Title": "Hermes Agent",
"X-OpenRouter-Categories": "productivity,cli-agent",
"X-OpenRouter-Categories": "productivity,cli-agent",
}
_aux_async_client = AsyncOpenAI(**_async_kwargs)
@ -65,10 +68,10 @@ _debug = DebugSession("vision_tools", env_var="VISION_TOOLS_DEBUG")
def _validate_image_url(url: str) -> bool:
"""
Basic validation of image URL format.
Args:
url (str): The URL to validate
Returns:
bool: True if URL appears to be valid, False otherwise
"""
@ -91,23 +94,22 @@ def _validate_image_url(url: str) -> bool:
async def _download_image(image_url: str, destination: Path, max_retries: int = 3) -> Path:
"""
Download an image from a URL to a local destination (async) with retry logic.
Args:
image_url (str): The URL of the image to download
destination (Path): The path where the image should be saved
max_retries (int): Maximum number of retry attempts (default: 3)
Returns:
Path: The path to the downloaded image
Raises:
Exception: If download fails after all retries
"""
import asyncio
# Create parent directories if they don't exist
destination.parent.mkdir(parents=True, exist_ok=True)
last_error = None
for attempt in range(max_retries):
try:
@ -122,10 +124,10 @@ async def _download_image(image_url: str, destination: Path, max_retries: int =
},
)
response.raise_for_status()
# Save the image content
destination.write_bytes(response.content)
return destination
except Exception as e:
last_error = e
@ -141,56 +143,56 @@ async def _download_image(image_url: str, destination: Path, max_retries: int =
str(e)[:100],
exc_info=True,
)
raise last_error
def _determine_mime_type(image_path: Path) -> str:
"""
Determine the MIME type of an image based on its file extension.
Args:
image_path (Path): Path to the image file
Returns:
str: The MIME type (defaults to image/jpeg if unknown)
"""
extension = image_path.suffix.lower()
mime_types = {
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.png': 'image/png',
'.gif': 'image/gif',
'.bmp': 'image/bmp',
'.webp': 'image/webp',
'.svg': 'image/svg+xml'
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".gif": "image/gif",
".bmp": "image/bmp",
".webp": "image/webp",
".svg": "image/svg+xml",
}
return mime_types.get(extension, 'image/jpeg')
return mime_types.get(extension, "image/jpeg")
def _image_to_base64_data_url(image_path: Path, mime_type: Optional[str] = None) -> str:
def _image_to_base64_data_url(image_path: Path, mime_type: str | None = None) -> str:
"""
Convert an image file to a base64-encoded data URL.
Args:
image_path (Path): Path to the image file
mime_type (Optional[str]): MIME type of the image (auto-detected if None)
Returns:
str: Base64-encoded data URL (e.g., "data:image/jpeg;base64,...")
"""
# Read the image as bytes
data = image_path.read_bytes()
# Encode to base64
encoded = base64.b64encode(data).decode("ascii")
# Determine MIME type
mime = mime_type or _determine_mime_type(image_path)
# Create data URL
data_url = f"data:{mime};base64,{encoded}"
return data_url
@ -201,31 +203,31 @@ async def vision_analyze_tool(
) -> str:
"""
Analyze an image from a URL or local file path using vision AI.
This tool accepts either an HTTP/HTTPS URL or a local file path. For URLs,
it downloads the image first. In both cases, the image is converted to base64
and processed using Gemini 3 Flash Preview via OpenRouter API.
The user_prompt parameter is expected to be pre-formatted by the calling
function (typically model_tools.py) to include both full description
requests and specific questions.
Args:
image_url (str): The URL or local file path of the image to analyze.
Accepts http://, https:// URLs or absolute/relative file paths.
user_prompt (str): The pre-formatted prompt for the vision model
model (str): The vision model to use (default: google/gemini-3-flash-preview)
Returns:
str: JSON string containing the analysis results with the following structure:
{
"success": bool,
"analysis": str (defaults to error message if None)
}
Raises:
Exception: If download fails, analysis fails, or API key is not set
Note:
- For URLs, temporary images are stored in ./temp_vision_images/ and cleaned up
- For local file paths, the file is used directly and NOT deleted
@ -235,36 +237,41 @@ async def vision_analyze_tool(
"parameters": {
"image_url": image_url,
"user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt,
"model": model
"model": model,
},
"error": None,
"success": False,
"analysis_length": 0,
"model_used": model,
"image_size_bytes": 0
"image_size_bytes": 0,
}
temp_image_path = None
# Track whether we should clean up the file after processing.
# Local files (e.g. from the image cache) should NOT be deleted.
should_cleanup = True
try:
from tools.interrupt import is_interrupted
if is_interrupted():
return json.dumps({"success": False, "error": "Interrupted"})
logger.info("Analyzing image: %s", image_url[:60])
logger.info("User prompt: %s", user_prompt[:100])
# Check auxiliary vision client availability
if _aux_async_client is None or DEFAULT_VISION_MODEL is None:
return json.dumps({
"success": False,
"analysis": "Vision analysis unavailable: no auxiliary vision model configured. "
"Set OPENROUTER_API_KEY or configure Nous Portal to enable vision tools."
}, indent=2, ensure_ascii=False)
return json.dumps(
{
"success": False,
"analysis": "Vision analysis unavailable: no auxiliary vision model configured. "
"Set OPENROUTER_API_KEY or configure Nous Portal to enable vision tools.",
},
indent=2,
ensure_ascii=False,
)
# Determine if this is a local file path or a remote URL
local_path = Path(image_url)
if local_path.is_file():
@ -280,50 +287,41 @@ async def vision_analyze_tool(
await _download_image(image_url, temp_image_path)
should_cleanup = True
else:
raise ValueError(
"Invalid image source. Provide an HTTP/HTTPS URL or a valid local file path."
)
raise ValueError("Invalid image source. Provide an HTTP/HTTPS URL or a valid local file path.")
# Get image file size for logging
image_size_bytes = temp_image_path.stat().st_size
image_size_kb = image_size_bytes / 1024
logger.info("Image ready (%.1f KB)", image_size_kb)
# Convert image to base64 data URL
logger.info("Converting image to base64...")
image_data_url = _image_to_base64_data_url(temp_image_path)
# Calculate size in KB for better readability
data_size_kb = len(image_data_url) / 1024
logger.info("Image converted to base64 (%.1f KB)", data_size_kb)
debug_call_data["image_size_bytes"] = image_size_bytes
# Use the prompt as provided (model_tools.py now handles full description formatting)
comprehensive_prompt = user_prompt
# Prepare the message with base64-encoded image
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": comprehensive_prompt
},
{
"type": "image_url",
"image_url": {
"url": image_data_url
}
}
]
{"type": "text", "text": comprehensive_prompt},
{"type": "image_url", "image_url": {"url": image_data_url}},
],
}
]
logger.info("Processing image with %s...", model)
# Call the vision API
from agent.auxiliary_client import get_auxiliary_extra_body, auxiliary_max_tokens_param
from agent.auxiliary_client import auxiliary_max_tokens_param, get_auxiliary_extra_body
_extra = get_auxiliary_extra_body()
response = await _aux_async_client.chat.completions.create(
model=model,
@ -332,44 +330,44 @@ async def vision_analyze_tool(
**auxiliary_max_tokens_param(2000),
**({} if not _extra else {"extra_body": _extra}),
)
# Extract the analysis
analysis = response.choices[0].message.content.strip()
analysis_length = len(analysis)
logger.info("Image analysis completed (%s characters)", analysis_length)
# Prepare successful response
result = {
"success": True,
"analysis": analysis or "There was a problem with the request and the image could not be analyzed."
"analysis": analysis or "There was a problem with the request and the image could not be analyzed.",
}
debug_call_data["success"] = True
debug_call_data["analysis_length"] = analysis_length
# Log debug information
_debug.log_call("vision_analyze_tool", debug_call_data)
_debug.save()
return json.dumps(result, indent=2, ensure_ascii=False)
except Exception as e:
error_msg = f"Error analyzing image: {str(e)}"
logger.error("%s", error_msg, exc_info=True)
# Prepare error response
result = {
"success": False,
"analysis": "There was a problem with the request and the image could not be analyzed."
"analysis": "There was a problem with the request and the image could not be analyzed.",
}
debug_call_data["error"] = error_msg
_debug.log_call("vision_analyze_tool", debug_call_data)
_debug.save()
return json.dumps(result, indent=2, ensure_ascii=False)
finally:
# Clean up temporary image file (but NOT local/cached files)
if should_cleanup and temp_image_path and temp_image_path.exists():
@ -377,9 +375,7 @@ async def vision_analyze_tool(
temp_image_path.unlink()
logger.debug("Cleaned up temporary image file")
except Exception as cleanup_error:
logger.warning(
"Could not delete temporary file: %s", cleanup_error, exc_info=True
)
logger.warning("Could not delete temporary file: %s", cleanup_error, exc_info=True)
def check_vision_requirements() -> bool:
@ -387,10 +383,10 @@ def check_vision_requirements() -> bool:
return _aux_async_client is not None
def get_debug_session_info() -> Dict[str, Any]:
def get_debug_session_info() -> dict[str, Any]:
"""
Get information about the current debug session.
Returns:
Dict[str, Any]: Dictionary containing debug session information
"""
@ -403,27 +399,27 @@ if __name__ == "__main__":
"""
print("👁️ Vision Tools Module")
print("=" * 40)
# Check if vision model is available
api_available = check_vision_requirements()
if not api_available:
print("❌ No auxiliary vision model available")
print("Set OPENROUTER_API_KEY or configure Nous Portal to enable vision tools.")
exit(1)
else:
print(f"✅ Vision model available: {DEFAULT_VISION_MODEL}")
print("🛠️ Vision tools ready for use!")
print(f"🧠 Using model: {DEFAULT_VISION_MODEL}")
# Show debug mode status
if _debug.active:
print(f"🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
print(f" Debug logs will be saved to: ./logs/vision_tools_debug_{_debug.session_id}.json")
else:
print("🐛 Debug mode disabled (set VISION_TOOLS_DEBUG=true to enable)")
print("\nBasic usage:")
print(" from vision_tools import vision_analyze_tool")
print(" import asyncio")
@ -435,14 +431,14 @@ if __name__ == "__main__":
print(" )")
print(" print(result)")
print(" asyncio.run(main())")
print("\nExample prompts:")
print(" - 'What architectural style is this building?'")
print(" - 'Describe the emotions and mood in this image'")
print(" - 'What text can you read in this image?'")
print(" - 'Identify any safety hazards visible'")
print(" - 'What products or brands are shown?'")
print("\nDebug mode:")
print(" # Enable debug logging")
print(" export VISION_TOOLS_DEBUG=true")
@ -461,30 +457,24 @@ VISION_ANALYZE_SCHEMA = {
"parameters": {
"type": "object",
"properties": {
"image_url": {
"type": "string",
"description": "Image URL (http/https) or local file path to analyze."
},
"image_url": {"type": "string", "description": "Image URL (http/https) or local file path to analyze."},
"question": {
"type": "string",
"description": "Your specific question or request about the image to resolve. The AI will automatically provide a complete image description AND answer your specific question."
}
"description": "Your specific question or request about the image to resolve. The AI will automatically provide a complete image description AND answer your specific question.",
},
},
"required": ["image_url", "question"]
}
"required": ["image_url", "question"],
},
}
def _handle_vision_analyze(args: Dict[str, Any], **kw: Any) -> Awaitable[str]:
def _handle_vision_analyze(args: dict[str, Any], **kw: Any) -> Awaitable[str]:
image_url = args.get("image_url", "")
question = args.get("question", "")
full_prompt = (
"Fully describe and explain everything about this image, then answer the "
f"following question:\n\n{question}"
f"Fully describe and explain everything about this image, then answer the following question:\n\n{question}"
)
model = (os.getenv("AUXILIARY_VISION_MODEL", "").strip()
or DEFAULT_VISION_MODEL
or "google/gemini-3-flash-preview")
model = os.getenv("AUXILIARY_VISION_MODEL", "").strip() or DEFAULT_VISION_MODEL or "google/gemini-3-flash-preview"
return vision_analyze_tool(image_url, full_prompt, model)

File diff suppressed because it is too large Load diff