mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-02 02:01:47 +00:00
feat: devex help, add Makefile, ruff, pre-commit, and modernize CI
This commit is contained in:
parent
172a38c344
commit
f4d7e6a29e
111 changed files with 11655 additions and 10200 deletions
|
|
@ -16,249 +16,222 @@ for the AI agent to access all capabilities.
|
|||
"""
|
||||
|
||||
# Export all tools for easy importing
|
||||
from .web_tools import (
|
||||
web_search_tool,
|
||||
web_extract_tool,
|
||||
web_crawl_tool,
|
||||
check_firecrawl_api_key
|
||||
)
|
||||
|
||||
# Primary terminal tool (mini-swe-agent backend: local/docker/singularity/modal/daytona)
|
||||
from .terminal_tool import (
|
||||
terminal_tool,
|
||||
check_terminal_requirements,
|
||||
cleanup_vm,
|
||||
cleanup_all_environments,
|
||||
get_active_environments_info,
|
||||
register_task_env_overrides,
|
||||
clear_task_env_overrides,
|
||||
TERMINAL_TOOL_DESCRIPTION
|
||||
)
|
||||
|
||||
from .vision_tools import (
|
||||
vision_analyze_tool,
|
||||
check_vision_requirements
|
||||
)
|
||||
|
||||
from .mixture_of_agents_tool import (
|
||||
mixture_of_agents_tool,
|
||||
check_moa_requirements
|
||||
)
|
||||
|
||||
from .image_generation_tool import (
|
||||
image_generate_tool,
|
||||
check_image_generation_requirements
|
||||
)
|
||||
|
||||
from .skills_tool import (
|
||||
skills_list,
|
||||
skill_view,
|
||||
check_skills_requirements,
|
||||
SKILLS_TOOL_DESCRIPTION
|
||||
)
|
||||
|
||||
from .skill_manager_tool import (
|
||||
skill_manage,
|
||||
check_skill_manage_requirements,
|
||||
SKILL_MANAGE_SCHEMA
|
||||
)
|
||||
|
||||
# Browser automation tools (agent-browser + Browserbase)
|
||||
from .browser_tool import (
|
||||
browser_navigate,
|
||||
browser_snapshot,
|
||||
browser_click,
|
||||
browser_type,
|
||||
browser_scroll,
|
||||
BROWSER_TOOL_SCHEMAS,
|
||||
browser_back,
|
||||
browser_press,
|
||||
browser_click,
|
||||
browser_close,
|
||||
browser_get_images,
|
||||
browser_navigate,
|
||||
browser_press,
|
||||
browser_scroll,
|
||||
browser_snapshot,
|
||||
browser_type,
|
||||
browser_vision,
|
||||
cleanup_browser,
|
||||
cleanup_all_browsers,
|
||||
get_active_browser_sessions,
|
||||
check_browser_requirements,
|
||||
BROWSER_TOOL_SCHEMAS
|
||||
)
|
||||
|
||||
# Cronjob management tools (CLI-only, hermes-cli toolset)
|
||||
from .cronjob_tools import (
|
||||
schedule_cronjob,
|
||||
list_cronjobs,
|
||||
remove_cronjob,
|
||||
check_cronjob_requirements,
|
||||
get_cronjob_tool_definitions,
|
||||
SCHEDULE_CRONJOB_SCHEMA,
|
||||
LIST_CRONJOBS_SCHEMA,
|
||||
REMOVE_CRONJOB_SCHEMA
|
||||
)
|
||||
|
||||
# RL Training tools (Tinker-Atropos)
|
||||
from .rl_training_tool import (
|
||||
rl_list_environments,
|
||||
rl_select_environment,
|
||||
rl_get_current_config,
|
||||
rl_edit_config,
|
||||
rl_start_training,
|
||||
rl_check_status,
|
||||
rl_stop_training,
|
||||
rl_get_results,
|
||||
rl_list_runs,
|
||||
rl_test_inference,
|
||||
check_rl_api_keys,
|
||||
get_missing_keys,
|
||||
)
|
||||
|
||||
# File manipulation tools (read, write, patch, search)
|
||||
from .file_tools import (
|
||||
read_file_tool,
|
||||
write_file_tool,
|
||||
patch_tool,
|
||||
search_tool,
|
||||
get_file_tools,
|
||||
clear_file_ops_cache,
|
||||
)
|
||||
|
||||
# Text-to-speech tools (Edge TTS / ElevenLabs / OpenAI)
|
||||
from .tts_tool import (
|
||||
text_to_speech_tool,
|
||||
check_tts_requirements,
|
||||
)
|
||||
|
||||
# Planning & task management tool
|
||||
from .todo_tool import (
|
||||
todo_tool,
|
||||
check_todo_requirements,
|
||||
TODO_SCHEMA,
|
||||
TodoStore,
|
||||
cleanup_all_browsers,
|
||||
cleanup_browser,
|
||||
get_active_browser_sessions,
|
||||
)
|
||||
|
||||
# Clarifying questions tool (interactive Q&A with the user)
|
||||
from .clarify_tool import (
|
||||
clarify_tool,
|
||||
check_clarify_requirements,
|
||||
CLARIFY_SCHEMA,
|
||||
check_clarify_requirements,
|
||||
clarify_tool,
|
||||
)
|
||||
|
||||
# Code execution sandbox (programmatic tool calling)
|
||||
from .code_execution_tool import (
|
||||
execute_code,
|
||||
check_sandbox_requirements,
|
||||
EXECUTE_CODE_SCHEMA,
|
||||
check_sandbox_requirements,
|
||||
execute_code,
|
||||
)
|
||||
|
||||
# Cronjob management tools (CLI-only, hermes-cli toolset)
|
||||
from .cronjob_tools import (
|
||||
LIST_CRONJOBS_SCHEMA,
|
||||
REMOVE_CRONJOB_SCHEMA,
|
||||
SCHEDULE_CRONJOB_SCHEMA,
|
||||
check_cronjob_requirements,
|
||||
get_cronjob_tool_definitions,
|
||||
list_cronjobs,
|
||||
remove_cronjob,
|
||||
schedule_cronjob,
|
||||
)
|
||||
|
||||
# Subagent delegation (spawn child agents with isolated context)
|
||||
from .delegate_tool import (
|
||||
delegate_task,
|
||||
check_delegate_requirements,
|
||||
DELEGATE_TASK_SCHEMA,
|
||||
check_delegate_requirements,
|
||||
delegate_task,
|
||||
)
|
||||
|
||||
# File manipulation tools (read, write, patch, search)
|
||||
from .file_tools import (
|
||||
clear_file_ops_cache,
|
||||
get_file_tools,
|
||||
patch_tool,
|
||||
read_file_tool,
|
||||
search_tool,
|
||||
write_file_tool,
|
||||
)
|
||||
from .image_generation_tool import check_image_generation_requirements, image_generate_tool
|
||||
from .mixture_of_agents_tool import check_moa_requirements, mixture_of_agents_tool
|
||||
|
||||
# RL Training tools (Tinker-Atropos)
|
||||
from .rl_training_tool import (
|
||||
check_rl_api_keys,
|
||||
get_missing_keys,
|
||||
rl_check_status,
|
||||
rl_edit_config,
|
||||
rl_get_current_config,
|
||||
rl_get_results,
|
||||
rl_list_environments,
|
||||
rl_list_runs,
|
||||
rl_select_environment,
|
||||
rl_start_training,
|
||||
rl_stop_training,
|
||||
rl_test_inference,
|
||||
)
|
||||
from .skill_manager_tool import SKILL_MANAGE_SCHEMA, check_skill_manage_requirements, skill_manage
|
||||
from .skills_tool import SKILLS_TOOL_DESCRIPTION, check_skills_requirements, skill_view, skills_list
|
||||
|
||||
# Primary terminal tool (mini-swe-agent backend: local/docker/singularity/modal/daytona)
|
||||
from .terminal_tool import (
|
||||
TERMINAL_TOOL_DESCRIPTION,
|
||||
check_terminal_requirements,
|
||||
cleanup_all_environments,
|
||||
cleanup_vm,
|
||||
clear_task_env_overrides,
|
||||
get_active_environments_info,
|
||||
register_task_env_overrides,
|
||||
terminal_tool,
|
||||
)
|
||||
|
||||
# Planning & task management tool
|
||||
from .todo_tool import (
|
||||
TODO_SCHEMA,
|
||||
TodoStore,
|
||||
check_todo_requirements,
|
||||
todo_tool,
|
||||
)
|
||||
|
||||
# Text-to-speech tools (Edge TTS / ElevenLabs / OpenAI)
|
||||
from .tts_tool import (
|
||||
check_tts_requirements,
|
||||
text_to_speech_tool,
|
||||
)
|
||||
from .vision_tools import check_vision_requirements, vision_analyze_tool
|
||||
from .web_tools import check_firecrawl_api_key, web_crawl_tool, web_extract_tool, web_search_tool
|
||||
|
||||
|
||||
# File tools have no external requirements - they use the terminal backend
|
||||
def check_file_requirements():
|
||||
"""File tools only require terminal backend to be available."""
|
||||
from .terminal_tool import check_terminal_requirements
|
||||
|
||||
return check_terminal_requirements()
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Web tools
|
||||
'web_search_tool',
|
||||
'web_extract_tool',
|
||||
'web_crawl_tool',
|
||||
'check_firecrawl_api_key',
|
||||
"web_search_tool",
|
||||
"web_extract_tool",
|
||||
"web_crawl_tool",
|
||||
"check_firecrawl_api_key",
|
||||
# Terminal tools (mini-swe-agent backend)
|
||||
'terminal_tool',
|
||||
'check_terminal_requirements',
|
||||
'cleanup_vm',
|
||||
'cleanup_all_environments',
|
||||
'get_active_environments_info',
|
||||
'register_task_env_overrides',
|
||||
'clear_task_env_overrides',
|
||||
'TERMINAL_TOOL_DESCRIPTION',
|
||||
"terminal_tool",
|
||||
"check_terminal_requirements",
|
||||
"cleanup_vm",
|
||||
"cleanup_all_environments",
|
||||
"get_active_environments_info",
|
||||
"register_task_env_overrides",
|
||||
"clear_task_env_overrides",
|
||||
"TERMINAL_TOOL_DESCRIPTION",
|
||||
# Vision tools
|
||||
'vision_analyze_tool',
|
||||
'check_vision_requirements',
|
||||
"vision_analyze_tool",
|
||||
"check_vision_requirements",
|
||||
# MoA tools
|
||||
'mixture_of_agents_tool',
|
||||
'check_moa_requirements',
|
||||
"mixture_of_agents_tool",
|
||||
"check_moa_requirements",
|
||||
# Image generation tools
|
||||
'image_generate_tool',
|
||||
'check_image_generation_requirements',
|
||||
"image_generate_tool",
|
||||
"check_image_generation_requirements",
|
||||
# Skills tools
|
||||
'skills_list',
|
||||
'skill_view',
|
||||
'check_skills_requirements',
|
||||
'SKILLS_TOOL_DESCRIPTION',
|
||||
"skills_list",
|
||||
"skill_view",
|
||||
"check_skills_requirements",
|
||||
"SKILLS_TOOL_DESCRIPTION",
|
||||
# Skill management
|
||||
'skill_manage',
|
||||
'check_skill_manage_requirements',
|
||||
'SKILL_MANAGE_SCHEMA',
|
||||
"skill_manage",
|
||||
"check_skill_manage_requirements",
|
||||
"SKILL_MANAGE_SCHEMA",
|
||||
# Browser automation tools
|
||||
'browser_navigate',
|
||||
'browser_snapshot',
|
||||
'browser_click',
|
||||
'browser_type',
|
||||
'browser_scroll',
|
||||
'browser_back',
|
||||
'browser_press',
|
||||
'browser_close',
|
||||
'browser_get_images',
|
||||
'browser_vision',
|
||||
'cleanup_browser',
|
||||
'cleanup_all_browsers',
|
||||
'get_active_browser_sessions',
|
||||
'check_browser_requirements',
|
||||
'BROWSER_TOOL_SCHEMAS',
|
||||
"browser_navigate",
|
||||
"browser_snapshot",
|
||||
"browser_click",
|
||||
"browser_type",
|
||||
"browser_scroll",
|
||||
"browser_back",
|
||||
"browser_press",
|
||||
"browser_close",
|
||||
"browser_get_images",
|
||||
"browser_vision",
|
||||
"cleanup_browser",
|
||||
"cleanup_all_browsers",
|
||||
"get_active_browser_sessions",
|
||||
"check_browser_requirements",
|
||||
"BROWSER_TOOL_SCHEMAS",
|
||||
# Cronjob management tools (CLI-only)
|
||||
'schedule_cronjob',
|
||||
'list_cronjobs',
|
||||
'remove_cronjob',
|
||||
'check_cronjob_requirements',
|
||||
'get_cronjob_tool_definitions',
|
||||
'SCHEDULE_CRONJOB_SCHEMA',
|
||||
'LIST_CRONJOBS_SCHEMA',
|
||||
'REMOVE_CRONJOB_SCHEMA',
|
||||
"schedule_cronjob",
|
||||
"list_cronjobs",
|
||||
"remove_cronjob",
|
||||
"check_cronjob_requirements",
|
||||
"get_cronjob_tool_definitions",
|
||||
"SCHEDULE_CRONJOB_SCHEMA",
|
||||
"LIST_CRONJOBS_SCHEMA",
|
||||
"REMOVE_CRONJOB_SCHEMA",
|
||||
# RL Training tools
|
||||
'rl_list_environments',
|
||||
'rl_select_environment',
|
||||
'rl_get_current_config',
|
||||
'rl_edit_config',
|
||||
'rl_start_training',
|
||||
'rl_check_status',
|
||||
'rl_stop_training',
|
||||
'rl_get_results',
|
||||
'rl_list_runs',
|
||||
'rl_test_inference',
|
||||
'check_rl_api_keys',
|
||||
'get_missing_keys',
|
||||
"rl_list_environments",
|
||||
"rl_select_environment",
|
||||
"rl_get_current_config",
|
||||
"rl_edit_config",
|
||||
"rl_start_training",
|
||||
"rl_check_status",
|
||||
"rl_stop_training",
|
||||
"rl_get_results",
|
||||
"rl_list_runs",
|
||||
"rl_test_inference",
|
||||
"check_rl_api_keys",
|
||||
"get_missing_keys",
|
||||
# File manipulation tools
|
||||
'read_file_tool',
|
||||
'write_file_tool',
|
||||
'patch_tool',
|
||||
'search_tool',
|
||||
'get_file_tools',
|
||||
'clear_file_ops_cache',
|
||||
'check_file_requirements',
|
||||
"read_file_tool",
|
||||
"write_file_tool",
|
||||
"patch_tool",
|
||||
"search_tool",
|
||||
"get_file_tools",
|
||||
"clear_file_ops_cache",
|
||||
"check_file_requirements",
|
||||
# Text-to-speech tools
|
||||
'text_to_speech_tool',
|
||||
'check_tts_requirements',
|
||||
"text_to_speech_tool",
|
||||
"check_tts_requirements",
|
||||
# Planning & task management tool
|
||||
'todo_tool',
|
||||
'check_todo_requirements',
|
||||
'TODO_SCHEMA',
|
||||
'TodoStore',
|
||||
"todo_tool",
|
||||
"check_todo_requirements",
|
||||
"TODO_SCHEMA",
|
||||
"TodoStore",
|
||||
# Clarifying questions tool
|
||||
'clarify_tool',
|
||||
'check_clarify_requirements',
|
||||
'CLARIFY_SCHEMA',
|
||||
"clarify_tool",
|
||||
"check_clarify_requirements",
|
||||
"CLARIFY_SCHEMA",
|
||||
# Code execution sandbox
|
||||
'execute_code',
|
||||
'check_sandbox_requirements',
|
||||
'EXECUTE_CODE_SCHEMA',
|
||||
"execute_code",
|
||||
"check_sandbox_requirements",
|
||||
"EXECUTE_CODE_SCHEMA",
|
||||
# Subagent delegation
|
||||
'delegate_task',
|
||||
'check_delegate_requirements',
|
||||
'DELEGATE_TASK_SCHEMA',
|
||||
"delegate_task",
|
||||
"check_delegate_requirements",
|
||||
"DELEGATE_TASK_SCHEMA",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ import os
|
|||
import re
|
||||
import sys
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -21,32 +20,32 @@ logger = logging.getLogger(__name__)
|
|||
# =========================================================================
|
||||
|
||||
DANGEROUS_PATTERNS = [
|
||||
(r'\brm\s+(-[^\s]*\s+)*/', "delete in root path"),
|
||||
(r'\brm\s+-[^\s]*r', "recursive delete"),
|
||||
(r'\brm\s+--recursive\b', "recursive delete (long flag)"),
|
||||
(r'\bchmod\s+(-[^\s]*\s+)*777\b', "world-writable permissions"),
|
||||
(r'\bchmod\s+--recursive\b.*777', "recursive world-writable (long flag)"),
|
||||
(r'\bchown\s+(-[^\s]*)?R\s+root', "recursive chown to root"),
|
||||
(r'\bchown\s+--recursive\b.*root', "recursive chown to root (long flag)"),
|
||||
(r'\bmkfs\b', "format filesystem"),
|
||||
(r'\bdd\s+.*if=', "disk copy"),
|
||||
(r'>\s*/dev/sd', "write to block device"),
|
||||
(r'\bDROP\s+(TABLE|DATABASE)\b', "SQL DROP"),
|
||||
(r'\bDELETE\s+FROM\b(?!.*\bWHERE\b)', "SQL DELETE without WHERE"),
|
||||
(r'\bTRUNCATE\s+(TABLE)?\s*\w', "SQL TRUNCATE"),
|
||||
(r'>\s*/etc/', "overwrite system config"),
|
||||
(r'\bsystemctl\s+(stop|disable|mask)\b', "stop/disable system service"),
|
||||
(r'\bkill\s+-9\s+-1\b', "kill all processes"),
|
||||
(r'\bpkill\s+-9\b', "force kill processes"),
|
||||
(r':()\s*{\s*:\s*\|\s*:&\s*}\s*;:', "fork bomb"),
|
||||
(r'\b(bash|sh|zsh)\s+-c\s+', "shell command via -c flag"),
|
||||
(r'\b(python[23]?|perl|ruby|node)\s+-[ec]\s+', "script execution via -e/-c flag"),
|
||||
(r'\b(curl|wget)\b.*\|\s*(ba)?sh\b', "pipe remote content to shell"),
|
||||
(r'\b(bash|sh|zsh|ksh)\s+<\s*<?\s*\(\s*(curl|wget)\b', "execute remote script via process substitution"),
|
||||
(r'\btee\b.*(/etc/|/dev/sd|\.ssh/|\.hermes/\.env)', "overwrite system file via tee"),
|
||||
(r'\bxargs\s+.*\brm\b', "xargs with rm"),
|
||||
(r'\bfind\b.*-exec\s+(/\S*/)?rm\b', "find -exec rm"),
|
||||
(r'\bfind\b.*-delete\b', "find -delete"),
|
||||
(r"\brm\s+(-[^\s]*\s+)*/", "delete in root path"),
|
||||
(r"\brm\s+-[^\s]*r", "recursive delete"),
|
||||
(r"\brm\s+--recursive\b", "recursive delete (long flag)"),
|
||||
(r"\bchmod\s+(-[^\s]*\s+)*777\b", "world-writable permissions"),
|
||||
(r"\bchmod\s+--recursive\b.*777", "recursive world-writable (long flag)"),
|
||||
(r"\bchown\s+(-[^\s]*)?R\s+root", "recursive chown to root"),
|
||||
(r"\bchown\s+--recursive\b.*root", "recursive chown to root (long flag)"),
|
||||
(r"\bmkfs\b", "format filesystem"),
|
||||
(r"\bdd\s+.*if=", "disk copy"),
|
||||
(r">\s*/dev/sd", "write to block device"),
|
||||
(r"\bDROP\s+(TABLE|DATABASE)\b", "SQL DROP"),
|
||||
(r"\bDELETE\s+FROM\b(?!.*\bWHERE\b)", "SQL DELETE without WHERE"),
|
||||
(r"\bTRUNCATE\s+(TABLE)?\s*\w", "SQL TRUNCATE"),
|
||||
(r">\s*/etc/", "overwrite system config"),
|
||||
(r"\bsystemctl\s+(stop|disable|mask)\b", "stop/disable system service"),
|
||||
(r"\bkill\s+-9\s+-1\b", "kill all processes"),
|
||||
(r"\bpkill\s+-9\b", "force kill processes"),
|
||||
(r":()\s*{\s*:\s*\|\s*:&\s*}\s*;:", "fork bomb"),
|
||||
(r"\b(bash|sh|zsh)\s+-c\s+", "shell command via -c flag"),
|
||||
(r"\b(python[23]?|perl|ruby|node)\s+-[ec]\s+", "script execution via -e/-c flag"),
|
||||
(r"\b(curl|wget)\b.*\|\s*(ba)?sh\b", "pipe remote content to shell"),
|
||||
(r"\b(bash|sh|zsh|ksh)\s+<\s*<?\s*\(\s*(curl|wget)\b", "execute remote script via process substitution"),
|
||||
(r"\btee\b.*(/etc/|/dev/sd|\.ssh/|\.hermes/\.env)", "overwrite system file via tee"),
|
||||
(r"\bxargs\s+.*\brm\b", "xargs with rm"),
|
||||
(r"\bfind\b.*-exec\s+(/\S*/)?rm\b", "find -exec rm"),
|
||||
(r"\bfind\b.*-delete\b", "find -delete"),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -54,6 +53,7 @@ DANGEROUS_PATTERNS = [
|
|||
# Detection
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def detect_dangerous_command(command: str) -> tuple:
|
||||
"""Check if a command matches any dangerous patterns.
|
||||
|
||||
|
|
@ -63,7 +63,7 @@ def detect_dangerous_command(command: str) -> tuple:
|
|||
command_lower = command.lower()
|
||||
for pattern, description in DANGEROUS_PATTERNS:
|
||||
if re.search(pattern, command_lower, re.IGNORECASE | re.DOTALL):
|
||||
pattern_key = pattern.split(r'\b')[1] if r'\b' in pattern else pattern[:20]
|
||||
pattern_key = pattern.split(r"\b")[1] if r"\b" in pattern else pattern[:20]
|
||||
return (True, pattern_key, description)
|
||||
return (False, None, None)
|
||||
|
||||
|
|
@ -84,7 +84,7 @@ def submit_pending(session_key: str, approval: dict):
|
|||
_pending[session_key] = approval
|
||||
|
||||
|
||||
def pop_pending(session_key: str) -> Optional[dict]:
|
||||
def pop_pending(session_key: str) -> dict | None:
|
||||
"""Retrieve and remove a pending approval for a session."""
|
||||
with _lock:
|
||||
return _pending.pop(session_key, None)
|
||||
|
|
@ -133,6 +133,7 @@ def clear_session(session_key: str):
|
|||
# Config persistence for permanent allowlist
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def load_permanent_allowlist() -> set:
|
||||
"""Load permanently allowed command patterns from config.
|
||||
|
||||
|
|
@ -141,6 +142,7 @@ def load_permanent_allowlist() -> set:
|
|||
"""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
patterns = set(config.get("command_allowlist", []) or [])
|
||||
if patterns:
|
||||
|
|
@ -154,6 +156,7 @@ def save_permanent_allowlist(patterns: set):
|
|||
"""Save permanently allowed command patterns to config."""
|
||||
try:
|
||||
from hermes_cli.config import load_config, save_config
|
||||
|
||||
config = load_config()
|
||||
config["command_allowlist"] = list(patterns)
|
||||
save_config(config)
|
||||
|
|
@ -165,9 +168,8 @@ def save_permanent_allowlist(patterns: set):
|
|||
# Approval prompting + orchestration
|
||||
# =========================================================================
|
||||
|
||||
def prompt_dangerous_approval(command: str, description: str,
|
||||
timeout_seconds: int = 60,
|
||||
approval_callback=None) -> str:
|
||||
|
||||
def prompt_dangerous_approval(command: str, description: str, timeout_seconds: int = 60, approval_callback=None) -> str:
|
||||
"""Prompt the user to approve a dangerous command (CLI only).
|
||||
|
||||
Args:
|
||||
|
|
@ -188,7 +190,7 @@ def prompt_dangerous_approval(command: str, description: str,
|
|||
print(f" ⚠️ DANGEROUS COMMAND: {description}")
|
||||
print(f" {command[:80]}{'...' if len(command) > 80 else ''}")
|
||||
print()
|
||||
print(f" [o]nce | [s]ession | [a]lways | [d]eny")
|
||||
print(" [o]nce | [s]ession | [a]lways | [d]eny")
|
||||
print()
|
||||
sys.stdout.flush()
|
||||
|
||||
|
|
@ -209,13 +211,13 @@ def prompt_dangerous_approval(command: str, description: str,
|
|||
return "deny"
|
||||
|
||||
choice = result["choice"]
|
||||
if choice in ('o', 'once'):
|
||||
if choice in ("o", "once"):
|
||||
print(" ✓ Allowed once")
|
||||
return "once"
|
||||
elif choice in ('s', 'session'):
|
||||
elif choice in ("s", "session"):
|
||||
print(" ✓ Allowed for this session")
|
||||
return "session"
|
||||
elif choice in ('a', 'always'):
|
||||
elif choice in ("a", "always"):
|
||||
print(" ✓ Added to permanent allowlist")
|
||||
return "always"
|
||||
else:
|
||||
|
|
@ -232,8 +234,7 @@ def prompt_dangerous_approval(command: str, description: str,
|
|||
sys.stdout.flush()
|
||||
|
||||
|
||||
def check_dangerous_command(command: str, env_type: str,
|
||||
approval_callback=None) -> dict:
|
||||
def check_dangerous_command(command: str, env_type: str, approval_callback=None) -> dict:
|
||||
"""Check if a command is dangerous and handle approval.
|
||||
|
||||
This is the main entry point called by terminal_tool before executing
|
||||
|
|
@ -265,11 +266,14 @@ def check_dangerous_command(command: str, env_type: str,
|
|||
return {"approved": True, "message": None}
|
||||
|
||||
if is_gateway or os.getenv("HERMES_EXEC_ASK"):
|
||||
submit_pending(session_key, {
|
||||
"command": command,
|
||||
"pattern_key": pattern_key,
|
||||
"description": description,
|
||||
})
|
||||
submit_pending(
|
||||
session_key,
|
||||
{
|
||||
"command": command,
|
||||
"pattern_key": pattern_key,
|
||||
"description": description,
|
||||
},
|
||||
)
|
||||
return {
|
||||
"approved": False,
|
||||
"pattern_key": pattern_key,
|
||||
|
|
@ -279,8 +283,7 @@ def check_dangerous_command(command: str, env_type: str,
|
|||
"message": f"⚠️ This command is potentially dangerous ({description}). Asking the user for approval...",
|
||||
}
|
||||
|
||||
choice = prompt_dangerous_approval(command, description,
|
||||
approval_callback=approval_callback)
|
||||
choice = prompt_dangerous_approval(command, description, approval_callback=approval_callback)
|
||||
|
||||
if choice == "deny":
|
||||
return {
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -12,8 +12,7 @@ a thin dispatcher that delegates to a platform-provided callback.
|
|||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
# Maximum number of predefined choices the agent can offer.
|
||||
# A 5th "Other (type your answer)" option is always appended by the UI.
|
||||
|
|
@ -22,8 +21,8 @@ MAX_CHOICES = 4
|
|||
|
||||
def clarify_tool(
|
||||
question: str,
|
||||
choices: Optional[List[str]] = None,
|
||||
callback: Optional[Callable] = None,
|
||||
choices: list[str] | None = None,
|
||||
callback: Callable | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Ask the user a question, optionally with multiple-choice options.
|
||||
|
|
@ -68,11 +67,14 @@ def clarify_tool(
|
|||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
return json.dumps({
|
||||
"question": question,
|
||||
"choices_offered": choices,
|
||||
"user_response": str(user_response).strip(),
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"question": question,
|
||||
"choices_offered": choices,
|
||||
"user_response": str(user_response).strip(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
def check_clarify_requirements() -> bool:
|
||||
|
|
@ -133,8 +135,7 @@ registry.register(
|
|||
toolset="clarify",
|
||||
schema=CLARIFY_SCHEMA,
|
||||
handler=lambda args, **kw: clarify_tool(
|
||||
question=args.get("question", ""),
|
||||
choices=args.get("choices"),
|
||||
callback=kw.get("callback")),
|
||||
question=args.get("question", ""), choices=args.get("choices"), callback=kw.get("callback")
|
||||
),
|
||||
check_fn=check_clarify_requirements,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ import time
|
|||
import uuid
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
# Availability gate: UDS requires a POSIX OS
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -40,21 +40,23 @@ SANDBOX_AVAILABLE = sys.platform != "win32"
|
|||
|
||||
# The 7 tools allowed inside the sandbox. The intersection of this list
|
||||
# and the session's enabled tools determines which stubs are generated.
|
||||
SANDBOX_ALLOWED_TOOLS = frozenset([
|
||||
"web_search",
|
||||
"web_extract",
|
||||
"read_file",
|
||||
"write_file",
|
||||
"search_files",
|
||||
"patch",
|
||||
"terminal",
|
||||
])
|
||||
SANDBOX_ALLOWED_TOOLS = frozenset(
|
||||
[
|
||||
"web_search",
|
||||
"web_extract",
|
||||
"read_file",
|
||||
"write_file",
|
||||
"search_files",
|
||||
"patch",
|
||||
"terminal",
|
||||
]
|
||||
)
|
||||
|
||||
# Resource limit defaults (overridable via config.yaml → code_execution.*)
|
||||
DEFAULT_TIMEOUT = 300 # 5 minutes
|
||||
DEFAULT_TIMEOUT = 300 # 5 minutes
|
||||
DEFAULT_MAX_TOOL_CALLS = 50
|
||||
MAX_STDOUT_BYTES = 50_000 # 50 KB
|
||||
MAX_STDERR_BYTES = 10_000 # 10 KB
|
||||
MAX_STDOUT_BYTES = 50_000 # 50 KB
|
||||
MAX_STDERR_BYTES = 10_000 # 10 KB
|
||||
|
||||
|
||||
def check_sandbox_requirements() -> bool:
|
||||
|
|
@ -114,7 +116,7 @@ _TOOL_STUBS = {
|
|||
}
|
||||
|
||||
|
||||
def generate_hermes_tools_module(enabled_tools: List[str]) -> str:
|
||||
def generate_hermes_tools_module(enabled_tools: list[str]) -> str:
|
||||
"""
|
||||
Build the source code for the hermes_tools.py stub module.
|
||||
|
||||
|
|
@ -128,11 +130,7 @@ def generate_hermes_tools_module(enabled_tools: List[str]) -> str:
|
|||
if tool_name not in _TOOL_STUBS:
|
||||
continue
|
||||
func_name, sig, doc, args_expr = _TOOL_STUBS[tool_name]
|
||||
stub_functions.append(
|
||||
f"def {func_name}({sig}):\n"
|
||||
f" {doc}\n"
|
||||
f" return _call({func_name!r}, {args_expr})\n"
|
||||
)
|
||||
stub_functions.append(f"def {func_name}({sig}):\n {doc}\n return _call({func_name!r}, {args_expr})\n")
|
||||
export_names.append(func_name)
|
||||
|
||||
header = '''\
|
||||
|
|
@ -223,7 +221,7 @@ def _rpc_server_loop(
|
|||
server_sock: socket.socket,
|
||||
task_id: str,
|
||||
tool_call_log: list,
|
||||
tool_call_counter: list, # mutable [int] so the thread can increment
|
||||
tool_call_counter: list, # mutable [int] so the thread can increment
|
||||
max_tool_calls: int,
|
||||
allowed_tools: frozenset,
|
||||
):
|
||||
|
|
@ -243,7 +241,7 @@ def _rpc_server_loop(
|
|||
while True:
|
||||
try:
|
||||
chunk = conn.recv(65536)
|
||||
except socket.timeout:
|
||||
except TimeoutError:
|
||||
break
|
||||
if not chunk:
|
||||
break
|
||||
|
|
@ -270,23 +268,22 @@ def _rpc_server_loop(
|
|||
# Enforce the allow-list
|
||||
if tool_name not in allowed_tools:
|
||||
available = ", ".join(sorted(allowed_tools))
|
||||
resp = json.dumps({
|
||||
"error": (
|
||||
f"Tool '{tool_name}' is not available in execute_code. "
|
||||
f"Available: {available}"
|
||||
)
|
||||
})
|
||||
resp = json.dumps(
|
||||
{"error": (f"Tool '{tool_name}' is not available in execute_code. Available: {available}")}
|
||||
)
|
||||
conn.sendall((resp + "\n").encode())
|
||||
continue
|
||||
|
||||
# Enforce tool call limit
|
||||
if tool_call_counter[0] >= max_tool_calls:
|
||||
resp = json.dumps({
|
||||
"error": (
|
||||
f"Tool call limit reached ({max_tool_calls}). "
|
||||
"No more tool calls allowed in this execution."
|
||||
)
|
||||
})
|
||||
resp = json.dumps(
|
||||
{
|
||||
"error": (
|
||||
f"Tool call limit reached ({max_tool_calls}). "
|
||||
"No more tool calls allowed in this execution."
|
||||
)
|
||||
}
|
||||
)
|
||||
conn.sendall((resp + "\n").encode())
|
||||
continue
|
||||
|
||||
|
|
@ -303,9 +300,7 @@ def _rpc_server_loop(
|
|||
sys.stdout = open(os.devnull, "w")
|
||||
sys.stderr = open(os.devnull, "w")
|
||||
try:
|
||||
result = handle_function_call(
|
||||
tool_name, tool_args, task_id=task_id
|
||||
)
|
||||
result = handle_function_call(tool_name, tool_args, task_id=task_id)
|
||||
finally:
|
||||
sys.stdout.close()
|
||||
sys.stderr.close()
|
||||
|
|
@ -318,15 +313,17 @@ def _rpc_server_loop(
|
|||
|
||||
# Log for observability
|
||||
args_preview = str(tool_args)[:80]
|
||||
tool_call_log.append({
|
||||
"tool": tool_name,
|
||||
"args_preview": args_preview,
|
||||
"duration": round(call_duration, 2),
|
||||
})
|
||||
tool_call_log.append(
|
||||
{
|
||||
"tool": tool_name,
|
||||
"args_preview": args_preview,
|
||||
"duration": round(call_duration, 2),
|
||||
}
|
||||
)
|
||||
|
||||
conn.sendall((result + "\n").encode())
|
||||
|
||||
except socket.timeout:
|
||||
except TimeoutError:
|
||||
pass
|
||||
except OSError:
|
||||
pass
|
||||
|
|
@ -342,10 +339,11 @@ def _rpc_server_loop(
|
|||
# Main entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def execute_code(
|
||||
code: str,
|
||||
task_id: Optional[str] = None,
|
||||
enabled_tools: Optional[List[str]] = None,
|
||||
task_id: str | None = None,
|
||||
enabled_tools: list[str] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Run a Python script in a sandboxed child process with RPC access
|
||||
|
|
@ -361,9 +359,7 @@ def execute_code(
|
|||
JSON string with execution results.
|
||||
"""
|
||||
if not SANDBOX_AVAILABLE:
|
||||
return json.dumps({
|
||||
"error": "execute_code is not available on Windows. Use normal tool calls instead."
|
||||
})
|
||||
return json.dumps({"error": "execute_code is not available on Windows. Use normal tool calls instead."})
|
||||
|
||||
if not code or not code.strip():
|
||||
return json.dumps({"error": "No code provided."})
|
||||
|
|
@ -397,9 +393,7 @@ def execute_code(
|
|||
|
||||
try:
|
||||
# Write the auto-generated hermes_tools module
|
||||
tools_src = generate_hermes_tools_module(
|
||||
list(sandbox_tools) if enabled_tools else list(SANDBOX_ALLOWED_TOOLS)
|
||||
)
|
||||
tools_src = generate_hermes_tools_module(list(sandbox_tools) if enabled_tools else list(SANDBOX_ALLOWED_TOOLS))
|
||||
with open(os.path.join(tmpdir, "hermes_tools.py"), "w") as f:
|
||||
f.write(tools_src)
|
||||
|
||||
|
|
@ -415,8 +409,12 @@ def execute_code(
|
|||
rpc_thread = threading.Thread(
|
||||
target=_rpc_server_loop,
|
||||
args=(
|
||||
server_sock, task_id, tool_call_log,
|
||||
tool_call_counter, max_tool_calls, sandbox_tools,
|
||||
server_sock,
|
||||
task_id,
|
||||
tool_call_log,
|
||||
tool_call_counter,
|
||||
max_tool_calls,
|
||||
sandbox_tools,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
|
|
@ -426,11 +424,24 @@ def execute_code(
|
|||
# Build a minimal environment for the child. We intentionally exclude
|
||||
# API keys and tokens to prevent credential exfiltration from LLM-
|
||||
# generated scripts. The child accesses tools via RPC, not direct API.
|
||||
_SAFE_ENV_PREFIXES = ("PATH", "HOME", "USER", "LANG", "LC_", "TERM",
|
||||
"TMPDIR", "TMP", "TEMP", "SHELL", "LOGNAME",
|
||||
"XDG_", "PYTHONPATH", "VIRTUAL_ENV", "CONDA")
|
||||
_SECRET_SUBSTRINGS = ("KEY", "TOKEN", "SECRET", "PASSWORD", "CREDENTIAL",
|
||||
"PASSWD", "AUTH")
|
||||
_SAFE_ENV_PREFIXES = (
|
||||
"PATH",
|
||||
"HOME",
|
||||
"USER",
|
||||
"LANG",
|
||||
"LC_",
|
||||
"TERM",
|
||||
"TMPDIR",
|
||||
"TMP",
|
||||
"TEMP",
|
||||
"SHELL",
|
||||
"LOGNAME",
|
||||
"XDG_",
|
||||
"PYTHONPATH",
|
||||
"VIRTUAL_ENV",
|
||||
"CONDA",
|
||||
)
|
||||
_SECRET_SUBSTRINGS = ("KEY", "TOKEN", "SECRET", "PASSWORD", "CREDENTIAL", "PASSWD", "AUTH")
|
||||
child_env = {}
|
||||
for k, v in os.environ.items():
|
||||
if any(s in k.upper() for s in _SECRET_SUBSTRINGS):
|
||||
|
|
@ -515,7 +526,7 @@ def execute_code(
|
|||
rpc_thread.join(timeout=3)
|
||||
|
||||
# Build response
|
||||
result: Dict[str, Any] = {
|
||||
result: dict[str, Any] = {
|
||||
"status": status,
|
||||
"output": stdout_text,
|
||||
"tool_calls_made": tool_call_counter[0],
|
||||
|
|
@ -538,17 +549,21 @@ def execute_code(
|
|||
except Exception as exc:
|
||||
duration = round(time.monotonic() - exec_start, 2)
|
||||
logging.exception("execute_code failed")
|
||||
return json.dumps({
|
||||
"status": "error",
|
||||
"error": str(exc),
|
||||
"tool_calls_made": tool_call_counter[0],
|
||||
"duration_seconds": duration,
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"status": "error",
|
||||
"error": str(exc),
|
||||
"tool_calls_made": tool_call_counter[0],
|
||||
"duration_seconds": duration,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
finally:
|
||||
# Cleanup temp dir and socket
|
||||
try:
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
except Exception as e:
|
||||
logger.debug("Could not clean temp dir: %s", e)
|
||||
|
|
@ -592,6 +607,7 @@ def _load_config() -> dict:
|
|||
"""Load code_execution config from CLI_CONFIG if available."""
|
||||
try:
|
||||
from cli import CLI_CONFIG
|
||||
|
||||
return CLI_CONFIG.get("code_execution", {})
|
||||
except Exception:
|
||||
return {}
|
||||
|
|
@ -604,27 +620,37 @@ def _load_config() -> dict:
|
|||
# Per-tool documentation lines for the execute_code description.
|
||||
# Ordered to match the canonical display order.
|
||||
_TOOL_DOC_LINES = [
|
||||
("web_search",
|
||||
" web_search(query: str, limit: int = 5) -> dict\n"
|
||||
" Returns {\"data\": {\"web\": [{\"url\", \"title\", \"description\"}, ...]}}"),
|
||||
("web_extract",
|
||||
" web_extract(urls: list[str]) -> dict\n"
|
||||
" Returns {\"results\": [{\"url\", \"title\", \"content\", \"error\"}, ...]} where content is markdown"),
|
||||
("read_file",
|
||||
" read_file(path: str, offset: int = 1, limit: int = 500) -> dict\n"
|
||||
" Lines are 1-indexed. Returns {\"content\": \"...\", \"total_lines\": N}"),
|
||||
("write_file",
|
||||
" write_file(path: str, content: str) -> dict\n"
|
||||
" Always overwrites the entire file."),
|
||||
("search_files",
|
||||
" search_files(pattern: str, target=\"content\", path=\".\", file_glob=None, limit=50) -> dict\n"
|
||||
" target: \"content\" (search inside files) or \"files\" (find files by name). Returns {\"matches\": [...]}"),
|
||||
("patch",
|
||||
" patch(path: str, old_string: str, new_string: str, replace_all: bool = False) -> dict\n"
|
||||
" Replaces old_string with new_string in the file."),
|
||||
("terminal",
|
||||
" terminal(command: str, timeout=None, workdir=None) -> dict\n"
|
||||
" Foreground only (no background/pty). Returns {\"output\": \"...\", \"exit_code\": N}"),
|
||||
(
|
||||
"web_search",
|
||||
" web_search(query: str, limit: int = 5) -> dict\n"
|
||||
' Returns {"data": {"web": [{"url", "title", "description"}, ...]}}',
|
||||
),
|
||||
(
|
||||
"web_extract",
|
||||
" web_extract(urls: list[str]) -> dict\n"
|
||||
' Returns {"results": [{"url", "title", "content", "error"}, ...]} where content is markdown',
|
||||
),
|
||||
(
|
||||
"read_file",
|
||||
" read_file(path: str, offset: int = 1, limit: int = 500) -> dict\n"
|
||||
' Lines are 1-indexed. Returns {"content": "...", "total_lines": N}',
|
||||
),
|
||||
("write_file", " write_file(path: str, content: str) -> dict\n Always overwrites the entire file."),
|
||||
(
|
||||
"search_files",
|
||||
' search_files(pattern: str, target="content", path=".", file_glob=None, limit=50) -> dict\n'
|
||||
' target: "content" (search inside files) or "files" (find files by name). Returns {"matches": [...]}',
|
||||
),
|
||||
(
|
||||
"patch",
|
||||
" patch(path: str, old_string: str, new_string: str, replace_all: bool = False) -> dict\n"
|
||||
" Replaces old_string with new_string in the file.",
|
||||
),
|
||||
(
|
||||
"terminal",
|
||||
" terminal(command: str, timeout=None, workdir=None) -> dict\n"
|
||||
' Foreground only (no background/pty). Returns {"output": "...", "exit_code": N}',
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -639,9 +665,7 @@ def build_execute_code_schema(enabled_sandbox_tools: set = None) -> dict:
|
|||
enabled_sandbox_tools = SANDBOX_ALLOWED_TOOLS
|
||||
|
||||
# Build tool documentation lines for only the enabled tools
|
||||
tool_lines = "\n".join(
|
||||
doc for name, doc in _TOOL_DOC_LINES if name in enabled_sandbox_tools
|
||||
)
|
||||
tool_lines = "\n".join(doc for name, doc in _TOOL_DOC_LINES if name in enabled_sandbox_tools)
|
||||
|
||||
# Build example import list from enabled tools
|
||||
import_examples = [n for n in ("web_search", "terminal") if n in enabled_sandbox_tools]
|
||||
|
|
@ -702,8 +726,7 @@ registry.register(
|
|||
toolset="code_execution",
|
||||
schema=EXECUTE_CODE_SCHEMA,
|
||||
handler=lambda args, **kw: execute_code(
|
||||
code=args.get("code", ""),
|
||||
task_id=kw.get("task_id"),
|
||||
enabled_tools=kw.get("enabled_tools")),
|
||||
code=args.get("code", ""), task_id=kw.get("task_id"), enabled_tools=kw.get("enabled_tools")
|
||||
),
|
||||
check_fn=check_sandbox_requirements,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,37 +11,44 @@ The prompt must contain ALL necessary information.
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
# Import from cron module (will be available when properly installed)
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from cron.jobs import create_job, get_job, list_jobs, remove_job
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cron prompt scanning — critical-severity patterns only, since cron prompts
|
||||
# run in fresh sessions with full tool access.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CRON_THREAT_PATTERNS = [
|
||||
(r'ignore\s+(?:\w+\s+)*(?:previous|all|above|prior)\s+(?:\w+\s+)*instructions', "prompt_injection"),
|
||||
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
|
||||
(r'system\s+prompt\s+override', "sys_prompt_override"),
|
||||
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
|
||||
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
|
||||
(r'wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_wget"),
|
||||
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)', "read_secrets"),
|
||||
(r'authorized_keys', "ssh_backdoor"),
|
||||
(r'/etc/sudoers|visudo', "sudoers_mod"),
|
||||
(r'rm\s+-rf\s+/', "destructive_root_rm"),
|
||||
(r"ignore\s+(?:\w+\s+)*(?:previous|all|above|prior)\s+(?:\w+\s+)*instructions", "prompt_injection"),
|
||||
(r"do\s+not\s+tell\s+the\s+user", "deception_hide"),
|
||||
(r"system\s+prompt\s+override", "sys_prompt_override"),
|
||||
(r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"),
|
||||
(r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"),
|
||||
(r"wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_wget"),
|
||||
(r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)", "read_secrets"),
|
||||
(r"authorized_keys", "ssh_backdoor"),
|
||||
(r"/etc/sudoers|visudo", "sudoers_mod"),
|
||||
(r"rm\s+-rf\s+/", "destructive_root_rm"),
|
||||
]
|
||||
|
||||
_CRON_INVISIBLE_CHARS = {
|
||||
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
|
||||
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
|
||||
"\u200b",
|
||||
"\u200c",
|
||||
"\u200d",
|
||||
"\u2060",
|
||||
"\ufeff",
|
||||
"\u202a",
|
||||
"\u202b",
|
||||
"\u202c",
|
||||
"\u202d",
|
||||
"\u202e",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -60,17 +67,18 @@ def _scan_cron_prompt(prompt: str) -> str:
|
|||
# Tool: schedule_cronjob
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def schedule_cronjob(
|
||||
prompt: str,
|
||||
schedule: str,
|
||||
name: Optional[str] = None,
|
||||
repeat: Optional[int] = None,
|
||||
deliver: Optional[str] = None,
|
||||
task_id: str = None
|
||||
name: str | None = None,
|
||||
repeat: int | None = None,
|
||||
deliver: str | None = None,
|
||||
task_id: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
Schedule an automated task to run the agent on a schedule.
|
||||
|
||||
|
||||
IMPORTANT: When the cronjob runs, it starts a COMPLETELY FRESH session.
|
||||
The agent will have NO memory of this conversation or any prior context.
|
||||
Therefore, the prompt MUST contain ALL necessary information:
|
||||
|
|
@ -78,12 +86,12 @@ def schedule_cronjob(
|
|||
- Specific file paths, URLs, or identifiers
|
||||
- Clear success criteria
|
||||
- Any relevant background information
|
||||
|
||||
|
||||
BAD prompt: "Check on that server issue"
|
||||
GOOD prompt: "SSH into server 192.168.1.100 as user 'deploy', check if nginx
|
||||
is running with 'systemctl status nginx', and verify the site
|
||||
GOOD prompt: "SSH into server 192.168.1.100 as user 'deploy', check if nginx
|
||||
is running with 'systemctl status nginx', and verify the site
|
||||
https://example.com returns HTTP 200. Report any issues found."
|
||||
|
||||
|
||||
Args:
|
||||
prompt: Complete, self-contained instructions for the future agent.
|
||||
Must include ALL context needed - the agent won't remember anything.
|
||||
|
|
@ -105,7 +113,7 @@ def schedule_cronjob(
|
|||
- "signal": Send to Signal home channel
|
||||
- "telegram:123456": Send to specific chat ID
|
||||
- "signal:+15551234567": Send to specific Signal number
|
||||
|
||||
|
||||
Returns:
|
||||
JSON with job_id, next_run time, and confirmation
|
||||
"""
|
||||
|
|
@ -124,17 +132,10 @@ def schedule_cronjob(
|
|||
"chat_id": origin_chat_id,
|
||||
"chat_name": os.getenv("HERMES_SESSION_CHAT_NAME"),
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
job = create_job(
|
||||
prompt=prompt,
|
||||
schedule=schedule,
|
||||
name=name,
|
||||
repeat=repeat,
|
||||
deliver=deliver,
|
||||
origin=origin
|
||||
)
|
||||
|
||||
job = create_job(prompt=prompt, schedule=schedule, name=name, repeat=repeat, deliver=deliver, origin=origin)
|
||||
|
||||
# Format repeat info for display
|
||||
times = job["repeat"].get("times")
|
||||
if times is None:
|
||||
|
|
@ -143,23 +144,23 @@ def schedule_cronjob(
|
|||
repeat_display = "once"
|
||||
else:
|
||||
repeat_display = f"{times} times"
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"job_id": job["id"],
|
||||
"name": job["name"],
|
||||
"schedule": job["schedule_display"],
|
||||
"repeat": repeat_display,
|
||||
"deliver": job.get("deliver", "local"),
|
||||
"next_run_at": job["next_run_at"],
|
||||
"message": f"Cronjob '{job['name']}' created. It will run {repeat_display}, deliver to {job.get('deliver', 'local')}, next at {job['next_run_at']}."
|
||||
}, indent=2)
|
||||
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"job_id": job["id"],
|
||||
"name": job["name"],
|
||||
"schedule": job["schedule_display"],
|
||||
"repeat": repeat_display,
|
||||
"deliver": job.get("deliver", "local"),
|
||||
"next_run_at": job["next_run_at"],
|
||||
"message": f"Cronjob '{job['name']}' created. It will run {repeat_display}, deliver to {job.get('deliver', 'local')}, next at {job['next_run_at']}.",
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, indent=2)
|
||||
return json.dumps({"success": False, "error": str(e)}, indent=2)
|
||||
|
||||
|
||||
SCHEDULE_CRONJOB_SCHEMA = {
|
||||
|
|
@ -177,7 +178,7 @@ The future agent will NOT remember anything from the current conversation.
|
|||
|
||||
SCHEDULE FORMATS:
|
||||
- One-shot: "30m", "2h", "1d" (runs once after delay)
|
||||
- Interval: "every 30m", "every 2h" (recurring)
|
||||
- Interval: "every 30m", "every 2h" (recurring)
|
||||
- Cron: "0 9 * * *" (cron expression for precise scheduling)
|
||||
- Timestamp: "2026-02-03T14:00:00" (specific date/time)
|
||||
|
||||
|
|
@ -202,27 +203,24 @@ Use for: reminders, periodic checks, scheduled reports, automated maintenance.""
|
|||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Complete, self-contained instructions. Must include ALL context - the future agent will have NO memory of this conversation."
|
||||
"description": "Complete, self-contained instructions. Must include ALL context - the future agent will have NO memory of this conversation.",
|
||||
},
|
||||
"schedule": {
|
||||
"type": "string",
|
||||
"description": "When to run: '30m' (once in 30min), 'every 30m' (recurring), '0 9 * * *' (cron), or ISO timestamp"
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Optional human-friendly name for the job"
|
||||
"description": "When to run: '30m' (once in 30min), 'every 30m' (recurring), '0 9 * * *' (cron), or ISO timestamp",
|
||||
},
|
||||
"name": {"type": "string", "description": "Optional human-friendly name for the job"},
|
||||
"repeat": {
|
||||
"type": "integer",
|
||||
"description": "How many times to run. Omit for default (once for one-shot, forever for recurring). Set to N for exactly N runs."
|
||||
"description": "How many times to run. Omit for default (once for one-shot, forever for recurring). Set to N for exactly N runs.",
|
||||
},
|
||||
"deliver": {
|
||||
"type": "string",
|
||||
"description": "Where to send output: 'origin' (back to this chat), 'local' (files only), 'telegram', 'discord', 'signal', or 'platform:chat_id'"
|
||||
}
|
||||
"description": "Where to send output: 'origin' (back to this chat), 'local' (files only), 'telegram', 'discord', 'signal', or 'platform:chat_id'",
|
||||
},
|
||||
},
|
||||
"required": ["prompt", "schedule"]
|
||||
}
|
||||
"required": ["prompt", "schedule"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -230,10 +228,11 @@ Use for: reminders, periodic checks, scheduled reports, automated maintenance.""
|
|||
# Tool: list_cronjobs
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str:
|
||||
"""
|
||||
List all scheduled cronjobs.
|
||||
|
||||
|
||||
Returns information about each job including:
|
||||
- Job ID (needed for removal)
|
||||
- Name
|
||||
|
|
@ -241,16 +240,16 @@ def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str:
|
|||
- Repeat status (completed/total or 'forever')
|
||||
- Next scheduled run time
|
||||
- Last run time and status (if any)
|
||||
|
||||
|
||||
Args:
|
||||
include_disabled: Whether to include disabled/completed jobs
|
||||
|
||||
|
||||
Returns:
|
||||
JSON array of all scheduled jobs
|
||||
"""
|
||||
try:
|
||||
jobs = list_jobs(include_disabled=include_disabled)
|
||||
|
||||
|
||||
formatted_jobs = []
|
||||
for job in jobs:
|
||||
# Format repeat status
|
||||
|
|
@ -260,31 +259,26 @@ def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str:
|
|||
repeat_status = "forever"
|
||||
else:
|
||||
repeat_status = f"{completed}/{times}"
|
||||
|
||||
formatted_jobs.append({
|
||||
"job_id": job["id"],
|
||||
"name": job["name"],
|
||||
"prompt_preview": job["prompt"][:100] + "..." if len(job["prompt"]) > 100 else job["prompt"],
|
||||
"schedule": job["schedule_display"],
|
||||
"repeat": repeat_status,
|
||||
"deliver": job.get("deliver", "local"),
|
||||
"next_run_at": job.get("next_run_at"),
|
||||
"last_run_at": job.get("last_run_at"),
|
||||
"last_status": job.get("last_status"),
|
||||
"enabled": job.get("enabled", True)
|
||||
})
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"count": len(formatted_jobs),
|
||||
"jobs": formatted_jobs
|
||||
}, indent=2)
|
||||
|
||||
|
||||
formatted_jobs.append(
|
||||
{
|
||||
"job_id": job["id"],
|
||||
"name": job["name"],
|
||||
"prompt_preview": job["prompt"][:100] + "..." if len(job["prompt"]) > 100 else job["prompt"],
|
||||
"schedule": job["schedule_display"],
|
||||
"repeat": repeat_status,
|
||||
"deliver": job.get("deliver", "local"),
|
||||
"next_run_at": job.get("next_run_at"),
|
||||
"last_run_at": job.get("last_run_at"),
|
||||
"last_status": job.get("last_status"),
|
||||
"enabled": job.get("enabled", True),
|
||||
}
|
||||
)
|
||||
|
||||
return json.dumps({"success": True, "count": len(formatted_jobs), "jobs": formatted_jobs}, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, indent=2)
|
||||
return json.dumps({"success": False, "error": str(e)}, indent=2)
|
||||
|
||||
|
||||
LIST_CRONJOBS_SCHEMA = {
|
||||
|
|
@ -302,11 +296,11 @@ Returns job_id, name, schedule, repeat status, next/last run times.""",
|
|||
"properties": {
|
||||
"include_disabled": {
|
||||
"type": "boolean",
|
||||
"description": "Include disabled/completed jobs in the list (default: false)"
|
||||
"description": "Include disabled/completed jobs in the list (default: false)",
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -314,48 +308,45 @@ Returns job_id, name, schedule, repeat status, next/last run times.""",
|
|||
# Tool: remove_cronjob
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def remove_cronjob(job_id: str, task_id: str = None) -> str:
|
||||
"""
|
||||
Remove a scheduled cronjob by its ID.
|
||||
|
||||
|
||||
Use list_cronjobs first to find the job_id of the job you want to remove.
|
||||
|
||||
|
||||
Args:
|
||||
job_id: The ID of the job to remove (from list_cronjobs output)
|
||||
|
||||
|
||||
Returns:
|
||||
JSON confirmation of removal
|
||||
"""
|
||||
try:
|
||||
job = get_job(job_id)
|
||||
if not job:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": f"Job with ID '{job_id}' not found. Use list_cronjobs to see available jobs."
|
||||
}, indent=2)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"Job with ID '{job_id}' not found. Use list_cronjobs to see available jobs.",
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
|
||||
removed = remove_job(job_id)
|
||||
if removed:
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"message": f"Cronjob '{job['name']}' (ID: {job_id}) has been removed.",
|
||||
"removed_job": {
|
||||
"id": job_id,
|
||||
"name": job["name"],
|
||||
"schedule": job["schedule_display"]
|
||||
}
|
||||
}, indent=2)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"Cronjob '{job['name']}' (ID: {job_id}) has been removed.",
|
||||
"removed_job": {"id": job_id, "name": job["name"], "schedule": job["schedule_display"]},
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
else:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": f"Failed to remove job '{job_id}'"
|
||||
}, indent=2)
|
||||
|
||||
return json.dumps({"success": False, "error": f"Failed to remove job '{job_id}'"}, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, indent=2)
|
||||
return json.dumps({"success": False, "error": str(e)}, indent=2)
|
||||
|
||||
|
||||
REMOVE_CRONJOB_SCHEMA = {
|
||||
|
|
@ -368,13 +359,10 @@ use this to cancel a job before it completes.""",
|
|||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"job_id": {
|
||||
"type": "string",
|
||||
"description": "The ID of the cronjob to remove (from list_cronjobs output)"
|
||||
}
|
||||
"job_id": {"type": "string", "description": "The ID of the cronjob to remove (from list_cronjobs output)"}
|
||||
},
|
||||
"required": ["job_id"]
|
||||
}
|
||||
"required": ["job_id"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -382,44 +370,34 @@ use this to cancel a job before it completes.""",
|
|||
# Requirements check
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def check_cronjob_requirements() -> bool:
|
||||
"""
|
||||
Check if cronjob tools can be used.
|
||||
|
||||
|
||||
Available in interactive CLI mode and gateway/messaging platforms.
|
||||
Cronjobs are server-side scheduled tasks so they work from any interface.
|
||||
"""
|
||||
return bool(
|
||||
os.getenv("HERMES_INTERACTIVE")
|
||||
or os.getenv("HERMES_GATEWAY_SESSION")
|
||||
or os.getenv("HERMES_EXEC_ASK")
|
||||
)
|
||||
return bool(os.getenv("HERMES_INTERACTIVE") or os.getenv("HERMES_GATEWAY_SESSION") or os.getenv("HERMES_EXEC_ASK"))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Exports
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_cronjob_tool_definitions():
|
||||
"""Return tool definitions for cronjob management."""
|
||||
return [
|
||||
SCHEDULE_CRONJOB_SCHEMA,
|
||||
LIST_CRONJOBS_SCHEMA,
|
||||
REMOVE_CRONJOB_SCHEMA
|
||||
]
|
||||
return [SCHEDULE_CRONJOB_SCHEMA, LIST_CRONJOBS_SCHEMA, REMOVE_CRONJOB_SCHEMA]
|
||||
|
||||
|
||||
# For direct testing
|
||||
if __name__ == "__main__":
|
||||
# Test the tools
|
||||
print("Testing schedule_cronjob:")
|
||||
result = schedule_cronjob(
|
||||
prompt="Test prompt for cron job",
|
||||
schedule="5m",
|
||||
name="Test Job"
|
||||
)
|
||||
result = schedule_cronjob(prompt="Test prompt for cron job", schedule="5m", name="Test Job")
|
||||
print(result)
|
||||
|
||||
|
||||
print("\nTesting list_cronjobs:")
|
||||
result = list_cronjobs()
|
||||
print(result)
|
||||
|
|
@ -438,7 +416,8 @@ registry.register(
|
|||
name=args.get("name"),
|
||||
repeat=args.get("repeat"),
|
||||
deliver=args.get("deliver"),
|
||||
task_id=kw.get("task_id")),
|
||||
task_id=kw.get("task_id"),
|
||||
),
|
||||
check_fn=check_cronjob_requirements,
|
||||
)
|
||||
registry.register(
|
||||
|
|
@ -446,16 +425,14 @@ registry.register(
|
|||
toolset="cronjob",
|
||||
schema=LIST_CRONJOBS_SCHEMA,
|
||||
handler=lambda args, **kw: list_cronjobs(
|
||||
include_disabled=args.get("include_disabled", False),
|
||||
task_id=kw.get("task_id")),
|
||||
include_disabled=args.get("include_disabled", False), task_id=kw.get("task_id")
|
||||
),
|
||||
check_fn=check_cronjob_requirements,
|
||||
)
|
||||
registry.register(
|
||||
name="remove_cronjob",
|
||||
toolset="cronjob",
|
||||
schema=REMOVE_CRONJOB_SCHEMA,
|
||||
handler=lambda args, **kw: remove_cronjob(
|
||||
job_id=args.get("job_id", ""),
|
||||
task_id=kw.get("task_id")),
|
||||
handler=lambda args, **kw: remove_cronjob(job_id=args.get("job_id", ""), task_id=kw.get("task_id")),
|
||||
check_fn=check_cronjob_requirements,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ import logging
|
|||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -44,27 +44,28 @@ class DebugSession:
|
|||
self.enabled = os.getenv(env_var, "false").lower() == "true"
|
||||
self.session_id = str(uuid.uuid4()) if self.enabled else ""
|
||||
self.log_dir = Path("./logs")
|
||||
self._calls: list[Dict[str, Any]] = []
|
||||
self._calls: list[dict[str, Any]] = []
|
||||
self._start_time = datetime.datetime.now().isoformat() if self.enabled else ""
|
||||
|
||||
if self.enabled:
|
||||
self.log_dir.mkdir(exist_ok=True)
|
||||
logger.debug("%s debug mode enabled - Session ID: %s",
|
||||
tool_name, self.session_id)
|
||||
logger.debug("%s debug mode enabled - Session ID: %s", tool_name, self.session_id)
|
||||
|
||||
@property
|
||||
def active(self) -> bool:
|
||||
return self.enabled
|
||||
|
||||
def log_call(self, call_name: str, call_data: Dict[str, Any]) -> None:
|
||||
def log_call(self, call_name: str, call_data: dict[str, Any]) -> None:
|
||||
"""Append a tool-call entry to the in-memory log."""
|
||||
if not self.enabled:
|
||||
return
|
||||
self._calls.append({
|
||||
"timestamp": datetime.datetime.now().isoformat(),
|
||||
"tool_name": call_name,
|
||||
**call_data,
|
||||
})
|
||||
self._calls.append(
|
||||
{
|
||||
"timestamp": datetime.datetime.now().isoformat(),
|
||||
"tool_name": call_name,
|
||||
**call_data,
|
||||
}
|
||||
)
|
||||
|
||||
def save(self) -> None:
|
||||
"""Flush the in-memory log to a JSON file in the logs directory."""
|
||||
|
|
@ -87,7 +88,7 @@ class DebugSession:
|
|||
except Exception as e:
|
||||
logger.error("Error saving %s debug log: %s", self.tool_name, e)
|
||||
|
||||
def get_session_info(self) -> Dict[str, Any]:
|
||||
def get_session_info(self) -> dict[str, Any]:
|
||||
"""Return a summary dict suitable for returning from get_debug_session_info()."""
|
||||
if not self.enabled:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -20,21 +20,22 @@ import contextlib
|
|||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from typing import Any
|
||||
|
||||
# Tools that children must never have access to
|
||||
DELEGATE_BLOCKED_TOOLS = frozenset([
|
||||
"delegate_task", # no recursive delegation
|
||||
"clarify", # no user interaction
|
||||
"memory", # no writes to shared MEMORY.md
|
||||
"send_message", # no cross-platform side effects
|
||||
"execute_code", # children should reason step-by-step, not write scripts
|
||||
])
|
||||
DELEGATE_BLOCKED_TOOLS = frozenset(
|
||||
[
|
||||
"delegate_task", # no recursive delegation
|
||||
"clarify", # no user interaction
|
||||
"memory", # no writes to shared MEMORY.md
|
||||
"send_message", # no cross-platform side effects
|
||||
"execute_code", # children should reason step-by-step, not write scripts
|
||||
]
|
||||
)
|
||||
|
||||
MAX_CONCURRENT_CHILDREN = 3
|
||||
MAX_DEPTH = 2 # parent (0) -> child (1) -> grandchild rejected (2)
|
||||
|
|
@ -47,7 +48,7 @@ def check_delegate_requirements() -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def _build_child_system_prompt(goal: str, context: Optional[str] = None) -> str:
|
||||
def _build_child_system_prompt(goal: str, context: str | None = None) -> str:
|
||||
"""Build a focused system prompt for a child agent."""
|
||||
parts = [
|
||||
"You are a focused subagent working on a specific delegated task.",
|
||||
|
|
@ -69,15 +70,18 @@ def _build_child_system_prompt(goal: str, context: Optional[str] = None) -> str:
|
|||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _strip_blocked_tools(toolsets: List[str]) -> List[str]:
|
||||
def _strip_blocked_tools(toolsets: list[str]) -> list[str]:
|
||||
"""Remove toolsets that contain only blocked tools."""
|
||||
blocked_toolset_names = {
|
||||
"delegation", "clarify", "memory", "code_execution",
|
||||
"delegation",
|
||||
"clarify",
|
||||
"memory",
|
||||
"code_execution",
|
||||
}
|
||||
return [t for t in toolsets if t not in blocked_toolset_names]
|
||||
|
||||
|
||||
def _build_child_progress_callback(task_index: int, parent_agent, task_count: int = 1) -> Optional[callable]:
|
||||
def _build_child_progress_callback(task_index: int, parent_agent, task_count: int = 1) -> Callable | None:
|
||||
"""Build a callback that relays child agent tool calls to the parent display.
|
||||
|
||||
Two display paths:
|
||||
|
|
@ -87,8 +91,8 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
|||
Returns None if no display mechanism is available, in which case the
|
||||
child agent runs with no progress callback (identical to current behavior).
|
||||
"""
|
||||
spinner = getattr(parent_agent, '_delegate_spinner', None)
|
||||
parent_cb = getattr(parent_agent, 'tool_progress_callback', None)
|
||||
spinner = getattr(parent_agent, "_delegate_spinner", None)
|
||||
parent_cb = getattr(parent_agent, "tool_progress_callback", None)
|
||||
|
||||
if not spinner and not parent_cb:
|
||||
return None # No display → no callback → zero behavior change
|
||||
|
|
@ -98,7 +102,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
|||
|
||||
# Gateway: batch tool names, flush periodically
|
||||
_BATCH_SIZE = 5
|
||||
_batch: List[str] = []
|
||||
_batch: list[str] = []
|
||||
|
||||
def _callback(tool_name: str, preview: str = None):
|
||||
# Special "_thinking" event: model produced text content (reasoning)
|
||||
|
|
@ -106,7 +110,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
|||
if spinner:
|
||||
short = (preview[:55] + "...") if preview and len(preview) > 55 else (preview or "")
|
||||
try:
|
||||
spinner.print_above(f" {prefix}├─ 💭 \"{short}\"")
|
||||
spinner.print_above(f' {prefix}├─ 💭 "{short}"')
|
||||
except Exception:
|
||||
pass
|
||||
# Don't relay thinking to gateway (too noisy for chat)
|
||||
|
|
@ -116,17 +120,25 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
|||
if spinner:
|
||||
short = (preview[:35] + "...") if preview and len(preview) > 35 else (preview or "")
|
||||
tool_emojis = {
|
||||
"terminal": "💻", "web_search": "🔍", "web_extract": "📄",
|
||||
"read_file": "📖", "write_file": "✍️", "patch": "🔧",
|
||||
"search_files": "🔎", "list_directory": "📂",
|
||||
"browser_navigate": "🌐", "browser_click": "👆",
|
||||
"text_to_speech": "🔊", "image_generate": "🎨",
|
||||
"vision_analyze": "👁️", "process": "⚙️",
|
||||
"terminal": "💻",
|
||||
"web_search": "🔍",
|
||||
"web_extract": "📄",
|
||||
"read_file": "📖",
|
||||
"write_file": "✍️",
|
||||
"patch": "🔧",
|
||||
"search_files": "🔎",
|
||||
"list_directory": "📂",
|
||||
"browser_navigate": "🌐",
|
||||
"browser_click": "👆",
|
||||
"text_to_speech": "🔊",
|
||||
"image_generate": "🎨",
|
||||
"vision_analyze": "👁️",
|
||||
"process": "⚙️",
|
||||
}
|
||||
emoji = tool_emojis.get(tool_name, "⚡")
|
||||
line = f" {prefix}├─ {emoji} {tool_name}"
|
||||
if short:
|
||||
line += f" \"{short}\""
|
||||
line += f' "{short}"'
|
||||
try:
|
||||
spinner.print_above(line)
|
||||
except Exception:
|
||||
|
|
@ -159,13 +171,13 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
|||
def _run_single_child(
|
||||
task_index: int,
|
||||
goal: str,
|
||||
context: Optional[str],
|
||||
toolsets: Optional[List[str]],
|
||||
model: Optional[str],
|
||||
context: str | None,
|
||||
toolsets: list[str] | None,
|
||||
model: str | None,
|
||||
max_iterations: int,
|
||||
parent_agent,
|
||||
task_count: int = 1,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Spawn and run a single child agent. Called from within a thread.
|
||||
Returns a structured result dict.
|
||||
|
|
@ -216,7 +228,7 @@ def _run_single_child(
|
|||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
clarify_callback=None,
|
||||
session_db=getattr(parent_agent, '_session_db', None),
|
||||
session_db=getattr(parent_agent, "_session_db", None),
|
||||
providers_allowed=parent_agent.providers_allowed,
|
||||
providers_ignored=parent_agent.providers_ignored,
|
||||
providers_order=parent_agent.providers_order,
|
||||
|
|
@ -226,10 +238,10 @@ def _run_single_child(
|
|||
)
|
||||
|
||||
# Set delegation depth so children can't spawn grandchildren
|
||||
child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1
|
||||
child._delegate_depth = getattr(parent_agent, "_delegate_depth", 0) + 1
|
||||
|
||||
# Register child for interrupt propagation
|
||||
if hasattr(parent_agent, '_active_children'):
|
||||
if hasattr(parent_agent, "_active_children"):
|
||||
parent_agent._active_children.append(child)
|
||||
|
||||
# Run with stdout/stderr suppressed to prevent interleaved output
|
||||
|
|
@ -238,7 +250,7 @@ def _run_single_child(
|
|||
result = child.run_conversation(user_message=goal)
|
||||
|
||||
# Flush any remaining batched progress to gateway
|
||||
if child_progress_cb and hasattr(child_progress_cb, '_flush'):
|
||||
if child_progress_cb and hasattr(child_progress_cb, "_flush"):
|
||||
try:
|
||||
child_progress_cb._flush()
|
||||
except Exception:
|
||||
|
|
@ -258,7 +270,7 @@ def _run_single_child(
|
|||
else:
|
||||
status = "failed"
|
||||
|
||||
entry: Dict[str, Any] = {
|
||||
entry: dict[str, Any] = {
|
||||
"task_index": task_index,
|
||||
"status": status,
|
||||
"summary": summary,
|
||||
|
|
@ -284,7 +296,7 @@ def _run_single_child(
|
|||
|
||||
finally:
|
||||
# Unregister child from interrupt propagation
|
||||
if hasattr(parent_agent, '_active_children'):
|
||||
if hasattr(parent_agent, "_active_children"):
|
||||
try:
|
||||
parent_agent._active_children.remove(child)
|
||||
except (ValueError, UnboundLocalError):
|
||||
|
|
@ -292,11 +304,11 @@ def _run_single_child(
|
|||
|
||||
|
||||
def delegate_task(
|
||||
goal: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
toolsets: Optional[List[str]] = None,
|
||||
tasks: Optional[List[Dict[str, Any]]] = None,
|
||||
max_iterations: Optional[int] = None,
|
||||
goal: str | None = None,
|
||||
context: str | None = None,
|
||||
toolsets: list[str] | None = None,
|
||||
tasks: list[dict[str, Any]] | None = None,
|
||||
max_iterations: int | None = None,
|
||||
parent_agent=None,
|
||||
) -> str:
|
||||
"""
|
||||
|
|
@ -312,14 +324,11 @@ def delegate_task(
|
|||
return json.dumps({"error": "delegate_task requires a parent agent context."})
|
||||
|
||||
# Depth limit
|
||||
depth = getattr(parent_agent, '_delegate_depth', 0)
|
||||
depth = getattr(parent_agent, "_delegate_depth", 0)
|
||||
if depth >= MAX_DEPTH:
|
||||
return json.dumps({
|
||||
"error": (
|
||||
f"Delegation depth limit reached ({MAX_DEPTH}). "
|
||||
"Subagents cannot spawn further subagents."
|
||||
)
|
||||
})
|
||||
return json.dumps(
|
||||
{"error": (f"Delegation depth limit reached ({MAX_DEPTH}). Subagents cannot spawn further subagents.")}
|
||||
)
|
||||
|
||||
# Load config
|
||||
cfg = _load_config()
|
||||
|
|
@ -366,7 +375,7 @@ def delegate_task(
|
|||
else:
|
||||
# Batch -- run in parallel with per-task progress lines
|
||||
completed_count = 0
|
||||
spinner_ref = getattr(parent_agent, '_delegate_spinner', None)
|
||||
spinner_ref = getattr(parent_agent, "_delegate_spinner", None)
|
||||
|
||||
# Save stdout/stderr before the executor — redirect_stdout in child
|
||||
# threads races on sys.stdout and can leave it as devnull permanently.
|
||||
|
|
@ -412,7 +421,7 @@ def delegate_task(
|
|||
status = entry.get("status", "?")
|
||||
icon = "✓" if status == "completed" else "✗"
|
||||
remaining = n_tasks - completed_count
|
||||
completion_line = f"{icon} [{idx+1}/{n_tasks}] {label} ({dur}s)"
|
||||
completion_line = f"{icon} [{idx + 1}/{n_tasks}] {label} ({dur}s)"
|
||||
if spinner_ref:
|
||||
try:
|
||||
spinner_ref.print_above(completion_line)
|
||||
|
|
@ -437,16 +446,20 @@ def delegate_task(
|
|||
|
||||
total_duration = round(time.monotonic() - overall_start, 2)
|
||||
|
||||
return json.dumps({
|
||||
"results": results,
|
||||
"total_duration_seconds": total_duration,
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"results": results,
|
||||
"total_duration_seconds": total_duration,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
"""Load delegation config from CLI_CONFIG if available."""
|
||||
try:
|
||||
from cli import CLI_CONFIG
|
||||
|
||||
return CLI_CONFIG.get("delegation", {})
|
||||
except Exception:
|
||||
return {}
|
||||
|
|
@ -537,10 +550,7 @@ DELEGATE_TASK_SCHEMA = {
|
|||
},
|
||||
"max_iterations": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Max tool-calling turns per subagent (default: 50). "
|
||||
"Only set lower for simple tasks."
|
||||
),
|
||||
"description": ("Max tool-calling turns per subagent (default: 50). Only set lower for simple tasks."),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
|
|
@ -561,6 +571,7 @@ registry.register(
|
|||
toolsets=args.get("toolsets"),
|
||||
tasks=args.get("tasks"),
|
||||
max_iterations=args.get("max_iterations"),
|
||||
parent_agent=kw.get("parent_agent")),
|
||||
parent_agent=kw.get("parent_agent"),
|
||||
),
|
||||
check_fn=check_delegate_requirements,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
"""Base class for all Hermes execution environment backends."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
import subprocess
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
|
|
@ -34,9 +34,9 @@ class BaseEnvironment(ABC):
|
|||
self.env = env or {}
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
"""Execute a command, return {"output": str, "returncode": int}."""
|
||||
...
|
||||
|
||||
|
|
@ -62,10 +62,10 @@ class BaseEnvironment(ABC):
|
|||
def _prepare_command(self, command: str) -> str:
|
||||
"""Transform sudo commands if SUDO_PASSWORD is available."""
|
||||
from tools.terminal_tool import _transform_sudo_command
|
||||
|
||||
return _transform_sudo_command(command)
|
||||
|
||||
def _build_run_kwargs(self, timeout: int | None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
def _build_run_kwargs(self, timeout: int | None, stdin_data: str | None = None) -> dict:
|
||||
"""Build common subprocess.run kwargs for non-interactive execution."""
|
||||
kw = {
|
||||
"text": True,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import shlex
|
|||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
|
|
@ -32,8 +31,8 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
cwd: str = "/home/daytona",
|
||||
timeout: int = 60,
|
||||
cpu: int = 1,
|
||||
memory: int = 5120, # MB (hermes convention)
|
||||
disk: int = 10240, # MB (Daytona platform max is 10GB)
|
||||
memory: int = 5120, # MB (hermes convention)
|
||||
disk: int = 10240, # MB (Daytona platform max is 10GB)
|
||||
persistent_filesystem: bool = True,
|
||||
task_id: str = "default",
|
||||
):
|
||||
|
|
@ -41,8 +40,8 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
|
||||
from daytona import (
|
||||
Daytona,
|
||||
CreateSandboxFromImageParams,
|
||||
Daytona,
|
||||
DaytonaError,
|
||||
Resources,
|
||||
SandboxState,
|
||||
|
|
@ -73,13 +72,11 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
try:
|
||||
self._sandbox = self._daytona.find_one(labels=labels)
|
||||
self._sandbox.start()
|
||||
logger.info("Daytona: resumed sandbox %s for task %s",
|
||||
self._sandbox.id, task_id)
|
||||
logger.info("Daytona: resumed sandbox %s for task %s", self._sandbox.id, task_id)
|
||||
except DaytonaError:
|
||||
self._sandbox = None
|
||||
except Exception as e:
|
||||
logger.warning("Daytona: failed to resume sandbox for task %s: %s",
|
||||
task_id, e)
|
||||
logger.warning("Daytona: failed to resume sandbox for task %s: %s", task_id, e)
|
||||
self._sandbox = None
|
||||
|
||||
# Create a fresh sandbox if we don't have one
|
||||
|
|
@ -92,8 +89,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
resources=resources,
|
||||
)
|
||||
)
|
||||
logger.info("Daytona: created sandbox %s for task %s",
|
||||
self._sandbox.id, task_id)
|
||||
logger.info("Daytona: created sandbox %s for task %s", self._sandbox.id, task_id)
|
||||
|
||||
# Resolve cwd: detect actual home dir inside the sandbox
|
||||
if self._requested_cwd in ("~", "/home/daytona"):
|
||||
|
|
@ -112,7 +108,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
self._sandbox.start()
|
||||
logger.info("Daytona: restarted sandbox %s", self._sandbox.id)
|
||||
|
||||
def _exec_in_thread(self, exec_command: str, cwd: Optional[str], timeout: int) -> dict:
|
||||
def _exec_in_thread(self, exec_command: str, cwd: str | None, timeout: int) -> dict:
|
||||
"""Run exec in a background thread with interrupt polling.
|
||||
|
||||
The Daytona SDK's exec(timeout=...) parameter is unreliable (the
|
||||
|
|
@ -130,7 +126,8 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
def _run():
|
||||
try:
|
||||
response = self._sandbox.process.exec(
|
||||
timed_command, cwd=cwd,
|
||||
timed_command,
|
||||
cwd=cwd,
|
||||
)
|
||||
result_holder["value"] = {
|
||||
"output": response.result or "",
|
||||
|
|
@ -169,9 +166,9 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
return {"error": result_holder["error"]}
|
||||
return result_holder["value"]
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: Optional[int] = None,
|
||||
stdin_data: Optional[str] = None) -> dict:
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
with self._lock:
|
||||
self._ensure_sandbox_ready()
|
||||
|
||||
|
|
@ -189,6 +186,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
|
||||
if "error" in result:
|
||||
from daytona import DaytonaError
|
||||
|
||||
err = result["error"]
|
||||
if isinstance(err, DaytonaError):
|
||||
with self._lock:
|
||||
|
|
@ -210,8 +208,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
|||
try:
|
||||
if self._persistent:
|
||||
self._sandbox.stop()
|
||||
logger.info("Daytona: stopped sandbox %s (filesystem preserved)",
|
||||
self._sandbox.id)
|
||||
logger.info("Daytona: stopped sandbox %s (filesystem preserved)", self._sandbox.id)
|
||||
else:
|
||||
self._daytona.delete(self._sandbox)
|
||||
logger.info("Daytona: deleted sandbox %s", self._sandbox.id)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import subprocess
|
|||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
|
|
@ -19,7 +18,6 @@ from tools.interrupt import is_interrupted
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
# Security flags applied to every container.
|
||||
# The container itself is the security boundary (isolated from host).
|
||||
# We drop all capabilities then add back the minimum needed:
|
||||
|
|
@ -28,19 +26,28 @@ logger = logging.getLogger(__name__)
|
|||
# Block privilege escalation and limit PIDs.
|
||||
# /tmp is size-limited and nosuid but allows exec (needed by pip/npm builds).
|
||||
_SECURITY_ARGS = [
|
||||
"--cap-drop", "ALL",
|
||||
"--cap-add", "DAC_OVERRIDE",
|
||||
"--cap-add", "CHOWN",
|
||||
"--cap-add", "FOWNER",
|
||||
"--security-opt", "no-new-privileges",
|
||||
"--pids-limit", "256",
|
||||
"--tmpfs", "/tmp:rw,nosuid,size=512m",
|
||||
"--tmpfs", "/var/tmp:rw,noexec,nosuid,size=256m",
|
||||
"--tmpfs", "/run:rw,noexec,nosuid,size=64m",
|
||||
"--cap-drop",
|
||||
"ALL",
|
||||
"--cap-add",
|
||||
"DAC_OVERRIDE",
|
||||
"--cap-add",
|
||||
"CHOWN",
|
||||
"--cap-add",
|
||||
"FOWNER",
|
||||
"--security-opt",
|
||||
"no-new-privileges",
|
||||
"--pids-limit",
|
||||
"256",
|
||||
"--tmpfs",
|
||||
"/tmp:rw,nosuid,size=512m",
|
||||
"--tmpfs",
|
||||
"/var/tmp:rw,noexec,nosuid,size=256m",
|
||||
"--tmpfs",
|
||||
"/run:rw,noexec,nosuid,size=64m",
|
||||
]
|
||||
|
||||
|
||||
_storage_opt_ok: Optional[bool] = None # cached result across instances
|
||||
_storage_opt_ok: bool | None = None # cached result across instances
|
||||
|
||||
|
||||
class DockerEnvironment(BaseEnvironment):
|
||||
|
|
@ -74,7 +81,7 @@ class DockerEnvironment(BaseEnvironment):
|
|||
self._base_image = image
|
||||
self._persistent = persistent_filesystem
|
||||
self._task_id = task_id
|
||||
self._container_id: Optional[str] = None
|
||||
self._container_id: str | None = None
|
||||
logger.info(f"DockerEnvironment volumes: {volumes}")
|
||||
# Ensure volumes is a list (config.yaml could be malformed)
|
||||
if volumes is not None and not isinstance(volumes, list):
|
||||
|
|
@ -105,8 +112,8 @@ class DockerEnvironment(BaseEnvironment):
|
|||
# mode uses tmpfs (ephemeral, fast, gone on cleanup).
|
||||
from tools.environments.base import get_sandbox_dir
|
||||
|
||||
self._workspace_dir: Optional[str] = None
|
||||
self._home_dir: Optional[str] = None
|
||||
self._workspace_dir: str | None = None
|
||||
self._home_dir: str | None = None
|
||||
if self._persistent:
|
||||
sandbox = get_sandbox_dir() / "docker" / task_id
|
||||
self._workspace_dir = str(sandbox / "workspace")
|
||||
|
|
@ -114,14 +121,19 @@ class DockerEnvironment(BaseEnvironment):
|
|||
os.makedirs(self._workspace_dir, exist_ok=True)
|
||||
os.makedirs(self._home_dir, exist_ok=True)
|
||||
writable_args = [
|
||||
"-v", f"{self._workspace_dir}:/workspace",
|
||||
"-v", f"{self._home_dir}:/root",
|
||||
"-v",
|
||||
f"{self._workspace_dir}:/workspace",
|
||||
"-v",
|
||||
f"{self._home_dir}:/root",
|
||||
]
|
||||
else:
|
||||
writable_args = [
|
||||
"--tmpfs", "/workspace:rw,exec,size=10g",
|
||||
"--tmpfs", "/home:rw,exec,size=1g",
|
||||
"--tmpfs", "/root:rw,exec,size=1g",
|
||||
"--tmpfs",
|
||||
"/workspace:rw,exec,size=10g",
|
||||
"--tmpfs",
|
||||
"/home:rw,exec,size=1g",
|
||||
"--tmpfs",
|
||||
"/root:rw,exec,size=1g",
|
||||
]
|
||||
|
||||
# All containers get security hardening (capabilities dropped, no privilege
|
||||
|
|
@ -129,7 +141,7 @@ class DockerEnvironment(BaseEnvironment):
|
|||
# can install packages as needed.
|
||||
# User-configured volume mounts (from config.yaml docker_volumes)
|
||||
volume_args = []
|
||||
for vol in (volumes or []):
|
||||
for vol in volumes or []:
|
||||
if not isinstance(vol, str):
|
||||
logger.warning(f"Docker volume entry is not a string: {vol!r}")
|
||||
continue
|
||||
|
|
@ -146,7 +158,9 @@ class DockerEnvironment(BaseEnvironment):
|
|||
logger.info(f"Docker run_args: {all_run_args}")
|
||||
|
||||
self._inner = _Docker(
|
||||
image=image, cwd=cwd, timeout=timeout,
|
||||
image=image,
|
||||
cwd=cwd,
|
||||
timeout=timeout,
|
||||
run_args=all_run_args,
|
||||
)
|
||||
self._container_id = self._inner.container_id
|
||||
|
|
@ -154,7 +168,7 @@ class DockerEnvironment(BaseEnvironment):
|
|||
@staticmethod
|
||||
def _storage_opt_supported() -> bool:
|
||||
"""Check if Docker's storage driver supports --storage-opt size=.
|
||||
|
||||
|
||||
Only overlay2 on XFS with pquota supports per-container disk quotas.
|
||||
Ubuntu (and most distros) default to ext4, where this flag errors out.
|
||||
"""
|
||||
|
|
@ -164,7 +178,9 @@ class DockerEnvironment(BaseEnvironment):
|
|||
try:
|
||||
result = subprocess.run(
|
||||
["docker", "info", "--format", "{{.Driver}}"],
|
||||
capture_output=True, text=True, timeout=10,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
driver = result.stdout.strip().lower()
|
||||
if driver != "overlay2":
|
||||
|
|
@ -174,14 +190,15 @@ class DockerEnvironment(BaseEnvironment):
|
|||
# Probe by attempting a dry-ish run — the fastest reliable check.
|
||||
probe = subprocess.run(
|
||||
["docker", "create", "--storage-opt", "size=1m", "hello-world"],
|
||||
capture_output=True, text=True, timeout=15,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=15,
|
||||
)
|
||||
if probe.returncode == 0:
|
||||
# Clean up the created container
|
||||
container_id = probe.stdout.strip()
|
||||
if container_id:
|
||||
subprocess.run(["docker", "rm", container_id],
|
||||
capture_output=True, timeout=5)
|
||||
subprocess.run(["docker", "rm", container_id], capture_output=True, timeout=5)
|
||||
_storage_opt_ok = True
|
||||
else:
|
||||
_storage_opt_ok = False
|
||||
|
|
@ -190,9 +207,9 @@ class DockerEnvironment(BaseEnvironment):
|
|||
logger.debug("Docker --storage-opt support: %s", _storage_opt_ok)
|
||||
return _storage_opt_ok
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
exec_command = self._prepare_command(command)
|
||||
work_dir = cwd or self.cwd
|
||||
effective_timeout = timeout or self.timeout
|
||||
|
|
@ -218,7 +235,8 @@ class DockerEnvironment(BaseEnvironment):
|
|||
_output_chunks = []
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if stdin_data else subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
|
|
@ -269,6 +287,7 @@ class DockerEnvironment(BaseEnvironment):
|
|||
|
||||
if not self._persistent:
|
||||
import shutil
|
||||
|
||||
for d in (self._workspace_dir, self._home_dir):
|
||||
if d:
|
||||
shutil.rmtree(d, ignore_errors=True)
|
||||
|
|
|
|||
|
|
@ -154,9 +154,9 @@ class LocalEnvironment(BaseEnvironment):
|
|||
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None):
|
||||
super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env)
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
from tools.terminal_tool import _interrupt_event
|
||||
|
||||
work_dir = cwd or self.cwd or os.getcwd()
|
||||
|
|
@ -172,11 +172,7 @@ class LocalEnvironment(BaseEnvironment):
|
|||
# Wrap with output fences so we can later extract the real
|
||||
# command output and discard shell init/exit noise.
|
||||
fenced_cmd = (
|
||||
f"printf '{_OUTPUT_FENCE}';"
|
||||
f" {exec_command};"
|
||||
f" __hermes_rc=$?;"
|
||||
f" printf '{_OUTPUT_FENCE}';"
|
||||
f" exit $__hermes_rc"
|
||||
f"printf '{_OUTPUT_FENCE}'; {exec_command}; __hermes_rc=$?; printf '{_OUTPUT_FENCE}'; exit $__hermes_rc"
|
||||
)
|
||||
# Ensure PATH always includes standard dirs — systemd services
|
||||
# and some terminal multiplexers inherit a minimal PATH.
|
||||
|
|
@ -200,12 +196,14 @@ class LocalEnvironment(BaseEnvironment):
|
|||
)
|
||||
|
||||
if stdin_data is not None:
|
||||
|
||||
def _write_stdin():
|
||||
try:
|
||||
proc.stdin.write(stdin_data)
|
||||
proc.stdin.close()
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
|
||||
threading.Thread(target=_write_stdin, daemon=True).start()
|
||||
|
||||
_output_chunks: list[str] = []
|
||||
|
|
|
|||
|
|
@ -8,10 +8,9 @@ project files, and config changes survive across sessions.
|
|||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
|
|
@ -21,7 +20,7 @@ logger = logging.getLogger(__name__)
|
|||
_SNAPSHOT_STORE = Path.home() / ".hermes" / "modal_snapshots.json"
|
||||
|
||||
|
||||
def _load_snapshots() -> Dict[str, str]:
|
||||
def _load_snapshots() -> dict[str, str]:
|
||||
"""Load snapshot ID mapping from disk."""
|
||||
if _SNAPSHOT_STORE.exists():
|
||||
try:
|
||||
|
|
@ -31,7 +30,7 @@ def _load_snapshots() -> Dict[str, str]:
|
|||
return {}
|
||||
|
||||
|
||||
def _save_snapshots(data: Dict[str, str]) -> None:
|
||||
def _save_snapshots(data: dict[str, str]) -> None:
|
||||
"""Persist snapshot ID mapping to disk."""
|
||||
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
|
||||
|
|
@ -52,7 +51,7 @@ class ModalEnvironment(BaseEnvironment):
|
|||
image: str,
|
||||
cwd: str = "~",
|
||||
timeout: int = 60,
|
||||
modal_sandbox_kwargs: Optional[Dict[str, Any]] = None,
|
||||
modal_sandbox_kwargs: dict[str, Any] | None = None,
|
||||
persistent_filesystem: bool = True,
|
||||
task_id: str = "default",
|
||||
):
|
||||
|
|
@ -61,6 +60,7 @@ class ModalEnvironment(BaseEnvironment):
|
|||
if not ModalEnvironment._patches_applied:
|
||||
try:
|
||||
from environments.patches import apply_patches
|
||||
|
||||
apply_patches()
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
@ -79,6 +79,7 @@ class ModalEnvironment(BaseEnvironment):
|
|||
if snapshot_id:
|
||||
try:
|
||||
import modal
|
||||
|
||||
restored_image = modal.Image.from_id(snapshot_id)
|
||||
logger.info("Modal: restoring from snapshot %s", snapshot_id[:20])
|
||||
except Exception as e:
|
||||
|
|
@ -88,6 +89,7 @@ class ModalEnvironment(BaseEnvironment):
|
|||
effective_image = restored_image if restored_image else image
|
||||
|
||||
from minisweagent.environments.extra.swerex_modal import SwerexModalEnvironment
|
||||
|
||||
self._inner = SwerexModalEnvironment(
|
||||
image=effective_image,
|
||||
cwd=cwd,
|
||||
|
|
@ -97,9 +99,9 @@ class ModalEnvironment(BaseEnvironment):
|
|||
modal_sandbox_kwargs=sandbox_kwargs,
|
||||
)
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
if stdin_data is not None:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
while marker in stdin_data:
|
||||
|
|
@ -139,29 +141,29 @@ class ModalEnvironment(BaseEnvironment):
|
|||
"""Snapshot the filesystem (if persistent) then stop the sandbox."""
|
||||
if self._persistent:
|
||||
try:
|
||||
sandbox = getattr(self._inner, 'deployment', None)
|
||||
sandbox = getattr(sandbox, '_sandbox', None) if sandbox else None
|
||||
sandbox = getattr(self._inner, "deployment", None)
|
||||
sandbox = getattr(sandbox, "_sandbox", None) if sandbox else None
|
||||
if sandbox:
|
||||
import asyncio
|
||||
|
||||
async def _snapshot():
|
||||
img = await sandbox.snapshot_filesystem.aio()
|
||||
return img.object_id
|
||||
|
||||
try:
|
||||
snapshot_id = asyncio.run(_snapshot())
|
||||
except RuntimeError:
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
snapshot_id = pool.submit(
|
||||
asyncio.run, _snapshot()
|
||||
).result(timeout=60)
|
||||
snapshot_id = pool.submit(asyncio.run, _snapshot()).result(timeout=60)
|
||||
|
||||
snapshots = _load_snapshots()
|
||||
snapshots[self._task_id] = snapshot_id
|
||||
_save_snapshots(snapshots)
|
||||
logger.info("Modal: saved filesystem snapshot %s for task %s",
|
||||
snapshot_id[:20], self._task_id)
|
||||
logger.info("Modal: saved filesystem snapshot %s for task %s", snapshot_id[:20], self._task_id)
|
||||
except Exception as e:
|
||||
logger.warning("Modal: filesystem snapshot failed: %s", e)
|
||||
|
||||
if hasattr(self._inner, 'stop'):
|
||||
if hasattr(self._inner, "stop"):
|
||||
self._inner.stop()
|
||||
|
|
|
|||
|
|
@ -10,11 +10,9 @@ import logging
|
|||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
|
|
@ -24,7 +22,7 @@ logger = logging.getLogger(__name__)
|
|||
_SNAPSHOT_STORE = Path.home() / ".hermes" / "singularity_snapshots.json"
|
||||
|
||||
|
||||
def _load_snapshots() -> Dict[str, str]:
|
||||
def _load_snapshots() -> dict[str, str]:
|
||||
if _SNAPSHOT_STORE.exists():
|
||||
try:
|
||||
return json.loads(_SNAPSHOT_STORE.read_text())
|
||||
|
|
@ -33,7 +31,7 @@ def _load_snapshots() -> Dict[str, str]:
|
|||
return {}
|
||||
|
||||
|
||||
def _save_snapshots(data: Dict[str, str]) -> None:
|
||||
def _save_snapshots(data: dict[str, str]) -> None:
|
||||
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
|
||||
|
||||
|
|
@ -42,6 +40,7 @@ def _save_snapshots(data: Dict[str, str]) -> None:
|
|||
# Singularity helpers (scratch dir, SIF cache, SIF building)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_scratch_dir() -> Path:
|
||||
"""Get the best directory for Singularity sandboxes.
|
||||
|
||||
|
|
@ -58,6 +57,7 @@ def _get_scratch_dir() -> Path:
|
|||
return scratch_path
|
||||
|
||||
from tools.environments.base import get_sandbox_dir
|
||||
|
||||
sandbox = get_sandbox_dir() / "singularity"
|
||||
|
||||
scratch = Path("/scratch")
|
||||
|
|
@ -93,12 +93,12 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
|
|||
Returns the path unchanged if it's already a .sif file.
|
||||
For docker:// URLs, checks the cache and builds if needed.
|
||||
"""
|
||||
if image.endswith('.sif') and Path(image).exists():
|
||||
if image.endswith(".sif") and Path(image).exists():
|
||||
return image
|
||||
if not image.startswith('docker://'):
|
||||
if not image.startswith("docker://"):
|
||||
return image
|
||||
|
||||
image_name = image.replace('docker://', '').replace('/', '-').replace(':', '-')
|
||||
image_name = image.replace("docker://", "").replace("/", "-").replace(":", "-")
|
||||
cache_dir = _get_apptainer_cache_dir()
|
||||
sif_path = cache_dir / f"{image_name}.sif"
|
||||
|
||||
|
|
@ -123,7 +123,10 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
|
|||
try:
|
||||
result = subprocess.run(
|
||||
[executable, "build", str(sif_path), image],
|
||||
capture_output=True, text=True, timeout=600, env=env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=600,
|
||||
env=env,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning("SIF build failed, falling back to docker:// URL")
|
||||
|
|
@ -145,6 +148,7 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
|
|||
# SingularityEnvironment
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SingularityEnvironment(BaseEnvironment):
|
||||
"""Hardened Singularity/Apptainer container with resource limits and persistence.
|
||||
|
||||
|
|
@ -174,7 +178,7 @@ class SingularityEnvironment(BaseEnvironment):
|
|||
self._instance_started = False
|
||||
self._persistent = persistent_filesystem
|
||||
self._task_id = task_id
|
||||
self._overlay_dir: Optional[Path] = None
|
||||
self._overlay_dir: Path | None = None
|
||||
|
||||
# Resource limits
|
||||
self._cpu = cpu
|
||||
|
|
@ -215,14 +219,13 @@ class SingularityEnvironment(BaseEnvironment):
|
|||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Failed to start instance: {result.stderr}")
|
||||
self._instance_started = True
|
||||
logger.info("Singularity instance %s started (persistent=%s)",
|
||||
self.instance_id, self._persistent)
|
||||
logger.info("Singularity instance %s started (persistent=%s)", self.instance_id, self._persistent)
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError("Instance start timed out")
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
if not self._instance_started:
|
||||
return {"output": "Instance not started", "returncode": -1}
|
||||
|
||||
|
|
@ -235,16 +238,16 @@ class SingularityEnvironment(BaseEnvironment):
|
|||
exec_command = f"cd {work_dir} && {exec_command}"
|
||||
work_dir = "/tmp"
|
||||
|
||||
cmd = [self.executable, "exec", "--pwd", work_dir,
|
||||
f"instance://{self.instance_id}",
|
||||
"bash", "-c", exec_command]
|
||||
cmd = [self.executable, "exec", "--pwd", work_dir, f"instance://{self.instance_id}", "bash", "-c", exec_command]
|
||||
|
||||
try:
|
||||
import time as _time
|
||||
|
||||
_output_chunks = []
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if stdin_data else subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
|
|
@ -295,7 +298,9 @@ class SingularityEnvironment(BaseEnvironment):
|
|||
try:
|
||||
subprocess.run(
|
||||
[self.executable, "instance", "stop", self.instance_id],
|
||||
capture_output=True, text=True, timeout=30,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
logger.info("Singularity instance %s stopped", self.instance_id)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -24,8 +24,7 @@ class SSHEnvironment(BaseEnvironment):
|
|||
and a remote kill is attempted over the ControlMaster socket.
|
||||
"""
|
||||
|
||||
def __init__(self, host: str, user: str, cwd: str = "~",
|
||||
timeout: int = 60, port: int = 22, key_path: str = ""):
|
||||
def __init__(self, host: str, user: str, cwd: str = "~", timeout: int = 60, port: int = 22, key_path: str = ""):
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
self.host = host
|
||||
self.user = user
|
||||
|
|
@ -65,12 +64,12 @@ class SSHEnvironment(BaseEnvironment):
|
|||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out")
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
work_dir = cwd or self.cwd
|
||||
exec_command = self._prepare_command(command)
|
||||
wrapped = f'cd {work_dir} && {exec_command}'
|
||||
wrapped = f"cd {work_dir} && {exec_command}"
|
||||
effective_timeout = timeout or self.timeout
|
||||
|
||||
cmd = self._build_ssh_command()
|
||||
|
|
@ -136,8 +135,7 @@ class SSHEnvironment(BaseEnvironment):
|
|||
def cleanup(self):
|
||||
if self.control_socket.exists():
|
||||
try:
|
||||
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
|
||||
"-O", "exit", f"{self.user}@{self.host}"]
|
||||
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}", "-O", "exit", f"{self.user}@{self.host}"]
|
||||
subprocess.run(cmd, capture_output=True, timeout=5)
|
||||
except (OSError, subprocess.SubprocessError):
|
||||
pass
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -3,11 +3,10 @@
|
|||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Optional
|
||||
from tools.file_operations import ShellFileOperations
|
||||
|
||||
from agent.redact import redact_sensitive_text
|
||||
from tools.file_operations import ShellFileOperations
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -25,14 +24,19 @@ def _get_file_ops(task_id: str = "default") -> ShellFileOperations:
|
|||
Thread-safe: uses the same per-task creation locks as terminal_tool to
|
||||
prevent duplicate sandbox creation from concurrent tool calls.
|
||||
"""
|
||||
from tools.terminal_tool import (
|
||||
_active_environments, _env_lock, _create_environment,
|
||||
_get_env_config, _last_activity, _start_cleanup_thread,
|
||||
_check_disk_usage_warning,
|
||||
_creation_locks, _creation_locks_lock,
|
||||
)
|
||||
import time
|
||||
|
||||
from tools.terminal_tool import (
|
||||
_active_environments,
|
||||
_create_environment,
|
||||
_creation_locks,
|
||||
_creation_locks_lock,
|
||||
_env_lock,
|
||||
_get_env_config,
|
||||
_last_activity,
|
||||
_start_cleanup_thread,
|
||||
)
|
||||
|
||||
# Fast path: check cache -- but also verify the underlying environment
|
||||
# is still alive (it may have been killed by the cleanup thread).
|
||||
with _file_ops_lock:
|
||||
|
|
@ -143,17 +147,23 @@ def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
|
|||
result = file_ops.write_file(path, content)
|
||||
return json.dumps(result.to_dict(), ensure_ascii=False)
|
||||
except Exception as e:
|
||||
print(f"[FileTools] write_file error: {type(e).__name__}: {e}", flush=True)
|
||||
print(f"[FileTools] write_file error: {type(e).__name__}: {e}", flush=True)
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
||||
new_string: str = None, replace_all: bool = False, patch: str = None,
|
||||
task_id: str = "default") -> str:
|
||||
def patch_tool(
|
||||
mode: str = "replace",
|
||||
path: str = None,
|
||||
old_string: str = None,
|
||||
new_string: str = None,
|
||||
replace_all: bool = False,
|
||||
patch: str = None,
|
||||
task_id: str = "default",
|
||||
) -> str:
|
||||
"""Patch a file using replace mode or V4A patch format."""
|
||||
try:
|
||||
file_ops = _get_file_ops(task_id)
|
||||
|
||||
|
||||
if mode == "replace":
|
||||
if not path:
|
||||
return json.dumps({"error": "path required"})
|
||||
|
|
@ -166,7 +176,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
|||
result = file_ops.patch_v4a(patch)
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown mode: {mode}"})
|
||||
|
||||
|
||||
result_dict = result.to_dict()
|
||||
result_json = json.dumps(result_dict, ensure_ascii=False)
|
||||
# Hint when old_string not found — saves iterations where the agent
|
||||
|
|
@ -178,20 +188,33 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
|||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
def search_tool(pattern: str, target: str = "content", path: str = ".",
|
||||
file_glob: str = None, limit: int = 50, offset: int = 0,
|
||||
output_mode: str = "content", context: int = 0,
|
||||
task_id: str = "default") -> str:
|
||||
def search_tool(
|
||||
pattern: str,
|
||||
target: str = "content",
|
||||
path: str = ".",
|
||||
file_glob: str = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
output_mode: str = "content",
|
||||
context: int = 0,
|
||||
task_id: str = "default",
|
||||
) -> str:
|
||||
"""Search for content or files."""
|
||||
try:
|
||||
file_ops = _get_file_ops(task_id)
|
||||
result = file_ops.search(
|
||||
pattern=pattern, path=path, target=target, file_glob=file_glob,
|
||||
limit=limit, offset=offset, output_mode=output_mode, context=context
|
||||
pattern=pattern,
|
||||
path=path,
|
||||
target=target,
|
||||
file_glob=file_glob,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
output_mode=output_mode,
|
||||
context=context,
|
||||
)
|
||||
if hasattr(result, 'matches'):
|
||||
if hasattr(result, "matches"):
|
||||
for m in result.matches:
|
||||
if hasattr(m, 'content') and m.content:
|
||||
if hasattr(m, "content") and m.content:
|
||||
m.content = redact_sensitive_text(m.content)
|
||||
result_dict = result.to_dict()
|
||||
result_json = json.dumps(result_dict, ensure_ascii=False)
|
||||
|
|
@ -209,7 +232,7 @@ FILE_TOOLS = [
|
|||
{"name": "read_file", "function": read_file_tool},
|
||||
{"name": "write_file", "function": write_file_tool},
|
||||
{"name": "patch", "function": patch_tool},
|
||||
{"name": "search_files", "function": search_tool}
|
||||
{"name": "search_files", "function": search_tool},
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -227,8 +250,10 @@ from tools.registry import registry
|
|||
def _check_file_reqs():
|
||||
"""Lazy wrapper to avoid circular import with tools/__init__.py."""
|
||||
from tools import check_file_requirements
|
||||
|
||||
return check_file_requirements()
|
||||
|
||||
|
||||
READ_FILE_SCHEMA = {
|
||||
"name": "read_file",
|
||||
"description": "Read a text file with line numbers and pagination. Use this instead of cat/head/tail in terminal. Output format: 'LINE_NUM|CONTENT'. Suggests similar filenames if not found. Use offset and limit for large files. NOTE: Cannot read images or binary files — use vision_analyze for images.",
|
||||
|
|
@ -236,11 +261,21 @@ READ_FILE_SCHEMA = {
|
|||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path to the file to read (absolute, relative, or ~/path)"},
|
||||
"offset": {"type": "integer", "description": "Line number to start reading from (1-indexed, default: 1)", "default": 1, "minimum": 1},
|
||||
"limit": {"type": "integer", "description": "Maximum number of lines to read (default: 500, max: 2000)", "default": 500, "maximum": 2000}
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (1-indexed, default: 1)",
|
||||
"default": 1,
|
||||
"minimum": 1,
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of lines to read (default: 500, max: 2000)",
|
||||
"default": 500,
|
||||
"maximum": 2000,
|
||||
},
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
"required": ["path"],
|
||||
},
|
||||
}
|
||||
|
||||
WRITE_FILE_SCHEMA = {
|
||||
|
|
@ -249,11 +284,14 @@ WRITE_FILE_SCHEMA = {
|
|||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path to the file to write (will be created if it doesn't exist, overwritten if it does)"},
|
||||
"content": {"type": "string", "description": "Complete content to write to the file"}
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to write (will be created if it doesn't exist, overwritten if it does)",
|
||||
},
|
||||
"content": {"type": "string", "description": "Complete content to write to the file"},
|
||||
},
|
||||
"required": ["path", "content"]
|
||||
}
|
||||
"required": ["path", "content"],
|
||||
},
|
||||
}
|
||||
|
||||
PATCH_SCHEMA = {
|
||||
|
|
@ -262,15 +300,33 @@ PATCH_SCHEMA = {
|
|||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"mode": {"type": "string", "enum": ["replace", "patch"], "description": "Edit mode: 'replace' for targeted find-and-replace, 'patch' for V4A multi-file patches", "default": "replace"},
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["replace", "patch"],
|
||||
"description": "Edit mode: 'replace' for targeted find-and-replace, 'patch' for V4A multi-file patches",
|
||||
"default": "replace",
|
||||
},
|
||||
"path": {"type": "string", "description": "File path to edit (required for 'replace' mode)"},
|
||||
"old_string": {"type": "string", "description": "Text to find in the file (required for 'replace' mode). Must be unique in the file unless replace_all=true. Include enough surrounding context to ensure uniqueness."},
|
||||
"new_string": {"type": "string", "description": "Replacement text (required for 'replace' mode). Can be empty string to delete the matched text."},
|
||||
"replace_all": {"type": "boolean", "description": "Replace all occurrences instead of requiring a unique match (default: false)", "default": False},
|
||||
"patch": {"type": "string", "description": "V4A format patch content (required for 'patch' mode). Format:\n*** Begin Patch\n*** Update File: path/to/file\n@@ context hint @@\n context line\n-removed line\n+added line\n*** End Patch"}
|
||||
"old_string": {
|
||||
"type": "string",
|
||||
"description": "Text to find in the file (required for 'replace' mode). Must be unique in the file unless replace_all=true. Include enough surrounding context to ensure uniqueness.",
|
||||
},
|
||||
"new_string": {
|
||||
"type": "string",
|
||||
"description": "Replacement text (required for 'replace' mode). Can be empty string to delete the matched text.",
|
||||
},
|
||||
"replace_all": {
|
||||
"type": "boolean",
|
||||
"description": "Replace all occurrences instead of requiring a unique match (default: false)",
|
||||
"default": False,
|
||||
},
|
||||
"patch": {
|
||||
"type": "string",
|
||||
"description": "V4A format patch content (required for 'patch' mode). Format:\n*** Begin Patch\n*** Update File: path/to/file\n@@ context hint @@\n context line\n-removed line\n+added line\n*** End Patch",
|
||||
},
|
||||
},
|
||||
"required": ["mode"]
|
||||
}
|
||||
"required": ["mode"],
|
||||
},
|
||||
}
|
||||
|
||||
SEARCH_FILES_SCHEMA = {
|
||||
|
|
@ -279,23 +335,57 @@ SEARCH_FILES_SCHEMA = {
|
|||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {"type": "string", "description": "Regex pattern for content search, or glob pattern (e.g., '*.py') for file search"},
|
||||
"target": {"type": "string", "enum": ["content", "files"], "description": "'content' searches inside file contents, 'files' searches for files by name", "default": "content"},
|
||||
"path": {"type": "string", "description": "Directory or file to search in (default: current working directory)", "default": "."},
|
||||
"file_glob": {"type": "string", "description": "Filter files by pattern in grep mode (e.g., '*.py' to only search Python files)"},
|
||||
"limit": {"type": "integer", "description": "Maximum number of results to return (default: 50)", "default": 50},
|
||||
"offset": {"type": "integer", "description": "Skip first N results for pagination (default: 0)", "default": 0},
|
||||
"output_mode": {"type": "string", "enum": ["content", "files_only", "count"], "description": "Output format for grep mode: 'content' shows matching lines with line numbers, 'files_only' lists file paths, 'count' shows match counts per file", "default": "content"},
|
||||
"context": {"type": "integer", "description": "Number of context lines before and after each match (grep mode only)", "default": 0}
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Regex pattern for content search, or glob pattern (e.g., '*.py') for file search",
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"enum": ["content", "files"],
|
||||
"description": "'content' searches inside file contents, 'files' searches for files by name",
|
||||
"default": "content",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory or file to search in (default: current working directory)",
|
||||
"default": ".",
|
||||
},
|
||||
"file_glob": {
|
||||
"type": "string",
|
||||
"description": "Filter files by pattern in grep mode (e.g., '*.py' to only search Python files)",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results to return (default: 50)",
|
||||
"default": 50,
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Skip first N results for pagination (default: 0)",
|
||||
"default": 0,
|
||||
},
|
||||
"output_mode": {
|
||||
"type": "string",
|
||||
"enum": ["content", "files_only", "count"],
|
||||
"description": "Output format for grep mode: 'content' shows matching lines with line numbers, 'files_only' lists file paths, 'count' shows match counts per file",
|
||||
"default": "content",
|
||||
},
|
||||
"context": {
|
||||
"type": "integer",
|
||||
"description": "Number of context lines before and after each match (grep mode only)",
|
||||
"default": 0,
|
||||
},
|
||||
},
|
||||
"required": ["pattern"]
|
||||
}
|
||||
"required": ["pattern"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _handle_read_file(args, **kw):
|
||||
tid = kw.get("task_id") or "default"
|
||||
return read_file_tool(path=args.get("path", ""), offset=args.get("offset", 1), limit=args.get("limit", 500), task_id=tid)
|
||||
return read_file_tool(
|
||||
path=args.get("path", ""), offset=args.get("offset", 1), limit=args.get("limit", 500), task_id=tid
|
||||
)
|
||||
|
||||
|
||||
def _handle_write_file(args, **kw):
|
||||
|
|
@ -306,9 +396,14 @@ def _handle_write_file(args, **kw):
|
|||
def _handle_patch(args, **kw):
|
||||
tid = kw.get("task_id") or "default"
|
||||
return patch_tool(
|
||||
mode=args.get("mode", "replace"), path=args.get("path"),
|
||||
old_string=args.get("old_string"), new_string=args.get("new_string"),
|
||||
replace_all=args.get("replace_all", False), patch=args.get("patch"), task_id=tid)
|
||||
mode=args.get("mode", "replace"),
|
||||
path=args.get("path"),
|
||||
old_string=args.get("old_string"),
|
||||
new_string=args.get("new_string"),
|
||||
replace_all=args.get("replace_all", False),
|
||||
patch=args.get("patch"),
|
||||
task_id=tid,
|
||||
)
|
||||
|
||||
|
||||
def _handle_search_files(args, **kw):
|
||||
|
|
@ -317,12 +412,29 @@ def _handle_search_files(args, **kw):
|
|||
raw_target = args.get("target", "content")
|
||||
target = target_map.get(raw_target, raw_target)
|
||||
return search_tool(
|
||||
pattern=args.get("pattern", ""), target=target, path=args.get("path", "."),
|
||||
file_glob=args.get("file_glob"), limit=args.get("limit", 50), offset=args.get("offset", 0),
|
||||
output_mode=args.get("output_mode", "content"), context=args.get("context", 0), task_id=tid)
|
||||
pattern=args.get("pattern", ""),
|
||||
target=target,
|
||||
path=args.get("path", "."),
|
||||
file_glob=args.get("file_glob"),
|
||||
limit=args.get("limit", 50),
|
||||
offset=args.get("offset", 0),
|
||||
output_mode=args.get("output_mode", "content"),
|
||||
context=args.get("context", 0),
|
||||
task_id=tid,
|
||||
)
|
||||
|
||||
|
||||
registry.register(name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs)
|
||||
registry.register(name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs)
|
||||
registry.register(
|
||||
name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs
|
||||
)
|
||||
registry.register(
|
||||
name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs
|
||||
)
|
||||
registry.register(name="patch", toolset="file", schema=PATCH_SCHEMA, handler=_handle_patch, check_fn=_check_file_reqs)
|
||||
registry.register(name="search_files", toolset="file", schema=SEARCH_FILES_SCHEMA, handler=_handle_search_files, check_fn=_check_file_reqs)
|
||||
registry.register(
|
||||
name="search_files",
|
||||
toolset="file",
|
||||
schema=SEARCH_FILES_SCHEMA,
|
||||
handler=_handle_search_files,
|
||||
check_fn=_check_file_reqs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ The 9-strategy chain (inspired by OpenCode):
|
|||
|
||||
Usage:
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
|
||||
|
||||
new_content, match_count, error = fuzzy_find_and_replace(
|
||||
content="def foo():\\n pass",
|
||||
old_string="def foo():",
|
||||
|
|
@ -29,21 +29,22 @@ Usage:
|
|||
"""
|
||||
|
||||
import re
|
||||
from typing import Tuple, Optional, List, Callable
|
||||
from collections.abc import Callable
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
|
||||
def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
|
||||
replace_all: bool = False) -> Tuple[str, int, Optional[str]]:
|
||||
def fuzzy_find_and_replace(
|
||||
content: str, old_string: str, new_string: str, replace_all: bool = False
|
||||
) -> tuple[str, int, str | None]:
|
||||
"""
|
||||
Find and replace text using a chain of increasingly fuzzy matching strategies.
|
||||
|
||||
|
||||
Args:
|
||||
content: The file content to search in
|
||||
old_string: The text to find
|
||||
new_string: The replacement text
|
||||
replace_all: If True, replace all occurrences; if False, require uniqueness
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (new_content, match_count, error_message)
|
||||
- If successful: (modified_content, number_of_replacements, None)
|
||||
|
|
@ -51,12 +52,12 @@ def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
|
|||
"""
|
||||
if not old_string:
|
||||
return content, 0, "old_string cannot be empty"
|
||||
|
||||
|
||||
if old_string == new_string:
|
||||
return content, 0, "old_string and new_string are identical"
|
||||
|
||||
|
||||
# Try each matching strategy in order
|
||||
strategies: List[Tuple[str, Callable]] = [
|
||||
strategies: list[tuple[str, Callable]] = [
|
||||
("exact", _strategy_exact),
|
||||
("line_trimmed", _strategy_line_trimmed),
|
||||
("whitespace_normalized", _strategy_whitespace_normalized),
|
||||
|
|
@ -66,46 +67,50 @@ def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
|
|||
("block_anchor", _strategy_block_anchor),
|
||||
("context_aware", _strategy_context_aware),
|
||||
]
|
||||
|
||||
|
||||
for strategy_name, strategy_fn in strategies:
|
||||
matches = strategy_fn(content, old_string)
|
||||
|
||||
|
||||
if matches:
|
||||
# Found matches with this strategy
|
||||
if len(matches) > 1 and not replace_all:
|
||||
return content, 0, (
|
||||
f"Found {len(matches)} matches for old_string. "
|
||||
f"Provide more context to make it unique, or use replace_all=True."
|
||||
return (
|
||||
content,
|
||||
0,
|
||||
(
|
||||
f"Found {len(matches)} matches for old_string. "
|
||||
f"Provide more context to make it unique, or use replace_all=True."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Perform replacement
|
||||
new_content = _apply_replacements(content, matches, new_string)
|
||||
return new_content, len(matches), None
|
||||
|
||||
|
||||
# No strategy found a match
|
||||
return content, 0, "Could not find a match for old_string in the file"
|
||||
|
||||
|
||||
def _apply_replacements(content: str, matches: List[Tuple[int, int]], new_string: str) -> str:
|
||||
def _apply_replacements(content: str, matches: list[tuple[int, int]], new_string: str) -> str:
|
||||
"""
|
||||
Apply replacements at the given positions.
|
||||
|
||||
|
||||
Args:
|
||||
content: Original content
|
||||
matches: List of (start, end) positions to replace
|
||||
new_string: Replacement text
|
||||
|
||||
|
||||
Returns:
|
||||
Content with replacements applied
|
||||
"""
|
||||
# Sort matches by position (descending) to replace from end to start
|
||||
# This preserves positions of earlier matches
|
||||
sorted_matches = sorted(matches, key=lambda x: x[0], reverse=True)
|
||||
|
||||
|
||||
result = content
|
||||
for start, end in sorted_matches:
|
||||
result = result[:start] + new_string + result[end:]
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -113,7 +118,8 @@ def _apply_replacements(content: str, matches: List[Tuple[int, int]], new_string
|
|||
# Matching Strategies
|
||||
# =============================================================================
|
||||
|
||||
def _strategy_exact(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
|
||||
def _strategy_exact(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
"""Strategy 1: Exact string match."""
|
||||
matches = []
|
||||
start = 0
|
||||
|
|
@ -126,206 +132,201 @@ def _strategy_exact(content: str, pattern: str) -> List[Tuple[int, int]]:
|
|||
return matches
|
||||
|
||||
|
||||
def _strategy_line_trimmed(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
def _strategy_line_trimmed(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Strategy 2: Match with line-by-line whitespace trimming.
|
||||
|
||||
|
||||
Strips leading/trailing whitespace from each line before matching.
|
||||
"""
|
||||
# Normalize pattern and content by trimming each line
|
||||
pattern_lines = [line.strip() for line in pattern.split('\n')]
|
||||
pattern_normalized = '\n'.join(pattern_lines)
|
||||
|
||||
content_lines = content.split('\n')
|
||||
pattern_lines = [line.strip() for line in pattern.split("\n")]
|
||||
pattern_normalized = "\n".join(pattern_lines)
|
||||
|
||||
content_lines = content.split("\n")
|
||||
content_normalized_lines = [line.strip() for line in content_lines]
|
||||
|
||||
|
||||
# Build mapping from normalized positions back to original positions
|
||||
return _find_normalized_matches(
|
||||
content, content_lines, content_normalized_lines,
|
||||
pattern, pattern_normalized
|
||||
)
|
||||
return _find_normalized_matches(content, content_lines, content_normalized_lines, pattern, pattern_normalized)
|
||||
|
||||
|
||||
def _strategy_whitespace_normalized(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
def _strategy_whitespace_normalized(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Strategy 3: Collapse multiple whitespace to single space.
|
||||
"""
|
||||
|
||||
def normalize(s):
|
||||
# Collapse multiple spaces/tabs to single space, preserve newlines
|
||||
return re.sub(r'[ \t]+', ' ', s)
|
||||
|
||||
return re.sub(r"[ \t]+", " ", s)
|
||||
|
||||
pattern_normalized = normalize(pattern)
|
||||
content_normalized = normalize(content)
|
||||
|
||||
|
||||
# Find in normalized, map back to original
|
||||
matches_in_normalized = _strategy_exact(content_normalized, pattern_normalized)
|
||||
|
||||
|
||||
if not matches_in_normalized:
|
||||
return []
|
||||
|
||||
|
||||
# Map positions back to original content
|
||||
return _map_normalized_positions(content, content_normalized, matches_in_normalized)
|
||||
|
||||
|
||||
def _strategy_indentation_flexible(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
def _strategy_indentation_flexible(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Strategy 4: Ignore indentation differences entirely.
|
||||
|
||||
|
||||
Strips all leading whitespace from lines before matching.
|
||||
"""
|
||||
|
||||
def strip_indent(s):
|
||||
return '\n'.join(line.lstrip() for line in s.split('\n'))
|
||||
|
||||
return "\n".join(line.lstrip() for line in s.split("\n"))
|
||||
|
||||
pattern_stripped = strip_indent(pattern)
|
||||
|
||||
content_lines = content.split('\n')
|
||||
|
||||
content_lines = content.split("\n")
|
||||
content_stripped_lines = [line.lstrip() for line in content_lines]
|
||||
pattern_lines = [line.lstrip() for line in pattern.split('\n')]
|
||||
|
||||
return _find_normalized_matches(
|
||||
content, content_lines, content_stripped_lines,
|
||||
pattern, '\n'.join(pattern_lines)
|
||||
)
|
||||
pattern_lines = [line.lstrip() for line in pattern.split("\n")]
|
||||
|
||||
return _find_normalized_matches(content, content_lines, content_stripped_lines, pattern, "\n".join(pattern_lines))
|
||||
|
||||
|
||||
def _strategy_escape_normalized(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
def _strategy_escape_normalized(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Strategy 5: Convert escape sequences to actual characters.
|
||||
|
||||
|
||||
Handles \\n -> newline, \\t -> tab, etc.
|
||||
"""
|
||||
|
||||
def unescape(s):
|
||||
# Convert common escape sequences
|
||||
return s.replace('\\n', '\n').replace('\\t', '\t').replace('\\r', '\r')
|
||||
|
||||
return s.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r")
|
||||
|
||||
pattern_unescaped = unescape(pattern)
|
||||
|
||||
|
||||
if pattern_unescaped == pattern:
|
||||
# No escapes to convert, skip this strategy
|
||||
return []
|
||||
|
||||
|
||||
return _strategy_exact(content, pattern_unescaped)
|
||||
|
||||
|
||||
def _strategy_trimmed_boundary(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
def _strategy_trimmed_boundary(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Strategy 6: Trim whitespace from first and last lines only.
|
||||
|
||||
|
||||
Useful when the pattern boundaries have whitespace differences.
|
||||
"""
|
||||
pattern_lines = pattern.split('\n')
|
||||
pattern_lines = pattern.split("\n")
|
||||
if not pattern_lines:
|
||||
return []
|
||||
|
||||
|
||||
# Trim only first and last lines
|
||||
pattern_lines[0] = pattern_lines[0].strip()
|
||||
if len(pattern_lines) > 1:
|
||||
pattern_lines[-1] = pattern_lines[-1].strip()
|
||||
|
||||
modified_pattern = '\n'.join(pattern_lines)
|
||||
|
||||
content_lines = content.split('\n')
|
||||
|
||||
|
||||
modified_pattern = "\n".join(pattern_lines)
|
||||
|
||||
content_lines = content.split("\n")
|
||||
|
||||
# Search through content for matching block
|
||||
matches = []
|
||||
pattern_line_count = len(pattern_lines)
|
||||
|
||||
|
||||
for i in range(len(content_lines) - pattern_line_count + 1):
|
||||
block_lines = content_lines[i:i + pattern_line_count]
|
||||
|
||||
block_lines = content_lines[i : i + pattern_line_count]
|
||||
|
||||
# Trim first and last of this block
|
||||
check_lines = block_lines.copy()
|
||||
check_lines[0] = check_lines[0].strip()
|
||||
if len(check_lines) > 1:
|
||||
check_lines[-1] = check_lines[-1].strip()
|
||||
|
||||
if '\n'.join(check_lines) == modified_pattern:
|
||||
|
||||
if "\n".join(check_lines) == modified_pattern:
|
||||
# Found match - calculate original positions
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def _strategy_block_anchor(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
def _strategy_block_anchor(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Strategy 7: Match by anchoring on first and last lines.
|
||||
|
||||
|
||||
If first and last lines match exactly, accept middle with 70% similarity.
|
||||
"""
|
||||
pattern_lines = pattern.split('\n')
|
||||
pattern_lines = pattern.split("\n")
|
||||
if len(pattern_lines) < 2:
|
||||
return [] # Need at least 2 lines for anchoring
|
||||
|
||||
|
||||
first_line = pattern_lines[0].strip()
|
||||
last_line = pattern_lines[-1].strip()
|
||||
|
||||
content_lines = content.split('\n')
|
||||
|
||||
content_lines = content.split("\n")
|
||||
matches = []
|
||||
|
||||
|
||||
pattern_line_count = len(pattern_lines)
|
||||
|
||||
|
||||
for i in range(len(content_lines) - pattern_line_count + 1):
|
||||
# Check if first and last lines match
|
||||
if (content_lines[i].strip() == first_line and
|
||||
content_lines[i + pattern_line_count - 1].strip() == last_line):
|
||||
|
||||
if content_lines[i].strip() == first_line and content_lines[i + pattern_line_count - 1].strip() == last_line:
|
||||
# Check middle similarity
|
||||
if pattern_line_count <= 2:
|
||||
# Only first and last, they match
|
||||
similarity = 1.0
|
||||
else:
|
||||
content_middle = '\n'.join(content_lines[i+1:i+pattern_line_count-1])
|
||||
pattern_middle = '\n'.join(pattern_lines[1:-1])
|
||||
content_middle = "\n".join(content_lines[i + 1 : i + pattern_line_count - 1])
|
||||
pattern_middle = "\n".join(pattern_lines[1:-1])
|
||||
similarity = SequenceMatcher(None, content_middle, pattern_middle).ratio()
|
||||
|
||||
|
||||
if similarity >= 0.70:
|
||||
# Calculate positions
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
def _strategy_context_aware(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Strategy 8: Line-by-line similarity with 50% threshold.
|
||||
|
||||
|
||||
Finds blocks where at least 50% of lines have high similarity.
|
||||
"""
|
||||
pattern_lines = pattern.split('\n')
|
||||
content_lines = content.split('\n')
|
||||
|
||||
pattern_lines = pattern.split("\n")
|
||||
content_lines = content.split("\n")
|
||||
|
||||
if not pattern_lines:
|
||||
return []
|
||||
|
||||
|
||||
matches = []
|
||||
pattern_line_count = len(pattern_lines)
|
||||
|
||||
|
||||
for i in range(len(content_lines) - pattern_line_count + 1):
|
||||
block_lines = content_lines[i:i + pattern_line_count]
|
||||
|
||||
block_lines = content_lines[i : i + pattern_line_count]
|
||||
|
||||
# Calculate line-by-line similarity
|
||||
high_similarity_count = 0
|
||||
for p_line, c_line in zip(pattern_lines, block_lines):
|
||||
sim = SequenceMatcher(None, p_line.strip(), c_line.strip()).ratio()
|
||||
if sim >= 0.80:
|
||||
high_similarity_count += 1
|
||||
|
||||
|
||||
# Need at least 50% of lines to have high similarity
|
||||
if high_similarity_count >= len(pattern_lines) * 0.5:
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
|
|
@ -333,74 +334,76 @@ def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]
|
|||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def _find_normalized_matches(content: str, content_lines: List[str],
|
||||
content_normalized_lines: List[str],
|
||||
pattern: str, pattern_normalized: str) -> List[Tuple[int, int]]:
|
||||
|
||||
def _find_normalized_matches(
|
||||
content: str, content_lines: list[str], content_normalized_lines: list[str], pattern: str, pattern_normalized: str
|
||||
) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Find matches in normalized content and map back to original positions.
|
||||
|
||||
|
||||
Args:
|
||||
content: Original content string
|
||||
content_lines: Original content split by lines
|
||||
content_normalized_lines: Normalized content lines
|
||||
pattern: Original pattern
|
||||
pattern_normalized: Normalized pattern
|
||||
|
||||
|
||||
Returns:
|
||||
List of (start, end) positions in the original content
|
||||
"""
|
||||
pattern_norm_lines = pattern_normalized.split('\n')
|
||||
pattern_norm_lines = pattern_normalized.split("\n")
|
||||
num_pattern_lines = len(pattern_norm_lines)
|
||||
|
||||
|
||||
matches = []
|
||||
|
||||
|
||||
for i in range(len(content_normalized_lines) - num_pattern_lines + 1):
|
||||
# Check if this block matches
|
||||
block = '\n'.join(content_normalized_lines[i:i + num_pattern_lines])
|
||||
|
||||
block = "\n".join(content_normalized_lines[i : i + num_pattern_lines])
|
||||
|
||||
if block == pattern_normalized:
|
||||
# Found a match - calculate original positions
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + num_pattern_lines]) - 1
|
||||
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[: i + num_pattern_lines]) - 1
|
||||
|
||||
# Handle case where end is past content
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
|
||||
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def _map_normalized_positions(original: str, normalized: str,
|
||||
normalized_matches: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
|
||||
def _map_normalized_positions(
|
||||
original: str, normalized: str, normalized_matches: list[tuple[int, int]]
|
||||
) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Map positions from normalized string back to original.
|
||||
|
||||
|
||||
This is a best-effort mapping that works for whitespace normalization.
|
||||
"""
|
||||
if not normalized_matches:
|
||||
return []
|
||||
|
||||
|
||||
# Build character mapping from normalized to original
|
||||
orig_to_norm = [] # orig_to_norm[i] = position in normalized
|
||||
|
||||
|
||||
orig_idx = 0
|
||||
norm_idx = 0
|
||||
|
||||
|
||||
while orig_idx < len(original) and norm_idx < len(normalized):
|
||||
if original[orig_idx] == normalized[norm_idx]:
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
norm_idx += 1
|
||||
elif original[orig_idx] in ' \t' and normalized[norm_idx] == ' ':
|
||||
elif original[orig_idx] in " \t" and normalized[norm_idx] == " ":
|
||||
# Original has space/tab, normalized collapsed to space
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
# Don't advance norm_idx yet - wait until all whitespace consumed
|
||||
if orig_idx < len(original) and original[orig_idx] not in ' \t':
|
||||
if orig_idx < len(original) and original[orig_idx] not in " \t":
|
||||
norm_idx += 1
|
||||
elif original[orig_idx] in ' \t':
|
||||
elif original[orig_idx] in " \t":
|
||||
# Extra whitespace in original
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
|
|
@ -408,21 +411,21 @@ def _map_normalized_positions(original: str, normalized: str,
|
|||
# Mismatch - shouldn't happen with our normalization
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
|
||||
|
||||
# Fill remaining
|
||||
while orig_idx < len(original):
|
||||
orig_to_norm.append(len(normalized))
|
||||
orig_idx += 1
|
||||
|
||||
|
||||
# Reverse mapping: for each normalized position, find original range
|
||||
norm_to_orig_start = {}
|
||||
norm_to_orig_end = {}
|
||||
|
||||
|
||||
for orig_pos, norm_pos in enumerate(orig_to_norm):
|
||||
if norm_pos not in norm_to_orig_start:
|
||||
norm_to_orig_start[norm_pos] = orig_pos
|
||||
norm_to_orig_end[norm_pos] = orig_pos
|
||||
|
||||
|
||||
# Map matches
|
||||
original_matches = []
|
||||
for norm_start, norm_end in normalized_matches:
|
||||
|
|
@ -432,17 +435,17 @@ def _map_normalized_positions(original: str, normalized: str,
|
|||
else:
|
||||
# Find nearest
|
||||
orig_start = min(i for i, n in enumerate(orig_to_norm) if n >= norm_start)
|
||||
|
||||
|
||||
# Find original end
|
||||
if norm_end - 1 in norm_to_orig_end:
|
||||
orig_end = norm_to_orig_end[norm_end - 1] + 1
|
||||
else:
|
||||
orig_end = orig_start + (norm_end - norm_start)
|
||||
|
||||
|
||||
# Expand to include trailing whitespace that was normalized
|
||||
while orig_end < len(original) and original[orig_end] in ' \t':
|
||||
while orig_end < len(original) and original[orig_end] in " \t":
|
||||
orig_end += 1
|
||||
|
||||
|
||||
original_matches.append((orig_start, min(orig_end, len(original))))
|
||||
|
||||
|
||||
return original_matches
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -35,23 +35,26 @@ def _get_config():
|
|||
_HASS_TOKEN or os.getenv("HASS_TOKEN", ""),
|
||||
)
|
||||
|
||||
|
||||
# Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1")
|
||||
_ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+$")
|
||||
|
||||
# Service domains blocked for security -- these allow arbitrary code/command
|
||||
# execution on the HA host or enable SSRF attacks on the local network.
|
||||
# HA provides zero service-level access control; all safety must be in our layer.
|
||||
_BLOCKED_DOMAINS = frozenset({
|
||||
"shell_command", # arbitrary shell commands as root in HA container
|
||||
"command_line", # sensors/switches that execute shell commands
|
||||
"python_script", # sandboxed but can escalate via hass.services.call()
|
||||
"pyscript", # scripting integration with broader access
|
||||
"hassio", # addon control, host shutdown/reboot, stdin to containers
|
||||
"rest_command", # HTTP requests from HA server (SSRF vector)
|
||||
})
|
||||
_BLOCKED_DOMAINS = frozenset(
|
||||
{
|
||||
"shell_command", # arbitrary shell commands as root in HA container
|
||||
"command_line", # sensors/switches that execute shell commands
|
||||
"python_script", # sandboxed but can escalate via hass.services.call()
|
||||
"pyscript", # scripting integration with broader access
|
||||
"hassio", # addon control, host shutdown/reboot, stdin to containers
|
||||
"rest_command", # HTTP requests from HA server (SSRF vector)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_headers(token: str = "") -> Dict[str, str]:
|
||||
def _get_headers(token: str = "") -> dict[str, str]:
|
||||
"""Return authorization headers for HA REST API."""
|
||||
if not token:
|
||||
_, token = _get_config()
|
||||
|
|
@ -65,11 +68,12 @@ def _get_headers(token: str = "") -> Dict[str, str]:
|
|||
# Async helpers (called from sync handlers via run_until_complete)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _filter_and_summarize(
|
||||
states: list,
|
||||
domain: Optional[str] = None,
|
||||
area: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
domain: str | None = None,
|
||||
area: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Filter raw HA states by domain/area and return a compact summary."""
|
||||
if domain:
|
||||
states = [s for s in states if s.get("entity_id", "").startswith(f"{domain}.")]
|
||||
|
|
@ -77,26 +81,29 @@ def _filter_and_summarize(
|
|||
if area:
|
||||
area_lower = area.lower()
|
||||
states = [
|
||||
s for s in states
|
||||
s
|
||||
for s in states
|
||||
if area_lower in (s.get("attributes", {}).get("friendly_name", "") or "").lower()
|
||||
or area_lower in (s.get("attributes", {}).get("area", "") or "").lower()
|
||||
]
|
||||
|
||||
entities = []
|
||||
for s in states:
|
||||
entities.append({
|
||||
"entity_id": s["entity_id"],
|
||||
"state": s["state"],
|
||||
"friendly_name": s.get("attributes", {}).get("friendly_name", ""),
|
||||
})
|
||||
entities.append(
|
||||
{
|
||||
"entity_id": s["entity_id"],
|
||||
"state": s["state"],
|
||||
"friendly_name": s.get("attributes", {}).get("friendly_name", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return {"count": len(entities), "entities": entities}
|
||||
|
||||
|
||||
async def _async_list_entities(
|
||||
domain: Optional[str] = None,
|
||||
area: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
domain: str | None = None,
|
||||
area: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch entity states from HA and optionally filter by domain/area."""
|
||||
import aiohttp
|
||||
|
||||
|
|
@ -110,7 +117,7 @@ async def _async_list_entities(
|
|||
return _filter_and_summarize(states, domain, area)
|
||||
|
||||
|
||||
async def _async_get_state(entity_id: str) -> Dict[str, Any]:
|
||||
async def _async_get_state(entity_id: str) -> dict[str, Any]:
|
||||
"""Fetch detailed state of a single entity."""
|
||||
import aiohttp
|
||||
|
||||
|
|
@ -131,11 +138,11 @@ async def _async_get_state(entity_id: str) -> Dict[str, Any]:
|
|||
|
||||
|
||||
def _build_service_payload(
|
||||
entity_id: Optional[str] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
entity_id: str | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build the JSON payload for a HA service call."""
|
||||
payload: Dict[str, Any] = {}
|
||||
payload: dict[str, Any] = {}
|
||||
if data:
|
||||
payload.update(data)
|
||||
# entity_id parameter takes precedence over data["entity_id"]
|
||||
|
|
@ -148,15 +155,17 @@ def _parse_service_response(
|
|||
domain: str,
|
||||
service: str,
|
||||
result: Any,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Parse HA service call response into a structured result."""
|
||||
affected = []
|
||||
if isinstance(result, list):
|
||||
for s in result:
|
||||
affected.append({
|
||||
"entity_id": s.get("entity_id", ""),
|
||||
"state": s.get("state", ""),
|
||||
})
|
||||
affected.append(
|
||||
{
|
||||
"entity_id": s.get("entity_id", ""),
|
||||
"state": s.get("state", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
|
|
@ -168,9 +177,9 @@ def _parse_service_response(
|
|||
async def _async_call_service(
|
||||
domain: str,
|
||||
service: str,
|
||||
entity_id: Optional[str] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
entity_id: str | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Call a Home Assistant service."""
|
||||
import aiohttp
|
||||
|
||||
|
|
@ -178,15 +187,17 @@ async def _async_call_service(
|
|||
url = f"{hass_url}/api/services/{domain}/{service}"
|
||||
payload = _build_service_payload(entity_id, data)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(
|
||||
url,
|
||||
headers=_get_headers(hass_token),
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=15),
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
result = await resp.json()
|
||||
) as resp,
|
||||
):
|
||||
resp.raise_for_status()
|
||||
result = await resp.json()
|
||||
|
||||
return _parse_service_response(domain, service, result)
|
||||
|
||||
|
|
@ -195,6 +206,7 @@ async def _async_call_service(
|
|||
# Sync wrappers (handler signature: (args, **kw) -> str)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_async(coro):
|
||||
"""Run an async coroutine from a sync handler."""
|
||||
try:
|
||||
|
|
@ -205,6 +217,7 @@ def _run_async(coro):
|
|||
if loop and loop.is_running():
|
||||
# Already inside an event loop -- create a new thread
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, coro)
|
||||
return future.result(timeout=30)
|
||||
|
|
@ -247,10 +260,12 @@ def _handle_call_service(args: dict, **kw) -> str:
|
|||
return json.dumps({"error": "Missing required parameters: domain and service"})
|
||||
|
||||
if domain in _BLOCKED_DOMAINS:
|
||||
return json.dumps({
|
||||
"error": f"Service domain '{domain}' is blocked for security. "
|
||||
f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
|
||||
})
|
||||
return json.dumps(
|
||||
{
|
||||
"error": f"Service domain '{domain}' is blocked for security. "
|
||||
f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
|
||||
}
|
||||
)
|
||||
|
||||
entity_id = args.get("entity_id")
|
||||
if entity_id and not _ENTITY_ID_RE.match(entity_id):
|
||||
|
|
@ -269,7 +284,8 @@ def _handle_call_service(args: dict, **kw) -> str:
|
|||
# List services
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _async_list_services(domain: Optional[str] = None) -> Dict[str, Any]:
|
||||
|
||||
async def _async_list_services(domain: str | None = None) -> dict[str, Any]:
|
||||
"""Fetch available services from HA and optionally filter by domain."""
|
||||
import aiohttp
|
||||
|
||||
|
|
@ -290,13 +306,10 @@ async def _async_list_services(domain: Optional[str] = None) -> Dict[str, Any]:
|
|||
d = svc_domain.get("domain", "")
|
||||
domain_services = {}
|
||||
for svc_name, svc_info in svc_domain.get("services", {}).items():
|
||||
svc_entry: Dict[str, Any] = {"description": svc_info.get("description", "")}
|
||||
svc_entry: dict[str, Any] = {"description": svc_info.get("description", "")}
|
||||
fields = svc_info.get("fields", {})
|
||||
if fields:
|
||||
svc_entry["fields"] = {
|
||||
k: v.get("description", "") for k, v in fields.items()
|
||||
if isinstance(v, dict)
|
||||
}
|
||||
svc_entry["fields"] = {k: v.get("description", "") for k, v in fields.items() if isinstance(v, dict)}
|
||||
domain_services[svc_name] = svc_entry
|
||||
result.append({"domain": d, "services": domain_services})
|
||||
|
||||
|
|
@ -318,6 +331,7 @@ def _handle_list_services(args: dict, **kw) -> str:
|
|||
# Availability check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _check_ha_available() -> bool:
|
||||
"""Tool is only available when HASS_TOKEN is set."""
|
||||
return bool(os.getenv("HASS_TOKEN"))
|
||||
|
|
@ -369,8 +383,7 @@ HA_GET_STATE_SCHEMA = {
|
|||
"entity_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The entity ID to query (e.g. 'light.living_room', "
|
||||
"'climate.thermostat', 'sensor.temperature')."
|
||||
"The entity ID to query (e.g. 'light.living_room', 'climate.thermostat', 'sensor.temperature')."
|
||||
),
|
||||
},
|
||||
},
|
||||
|
|
@ -392,8 +405,7 @@ HA_LIST_SERVICES_SCHEMA = {
|
|||
"domain": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Filter by domain (e.g. 'light', 'climate', 'switch'). "
|
||||
"Omit to list services for all domains."
|
||||
"Filter by domain (e.g. 'light', 'climate', 'switch'). Omit to list services for all domains."
|
||||
),
|
||||
},
|
||||
},
|
||||
|
|
@ -428,8 +440,7 @@ HA_CALL_SERVICE_SCHEMA = {
|
|||
"entity_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Target entity ID (e.g. 'light.living_room'). "
|
||||
"Some services (like scene.turn_on) may not need this."
|
||||
"Target entity ID (e.g. 'light.living_room'). Some services (like scene.turn_on) may not need this."
|
||||
),
|
||||
},
|
||||
"data": {
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ HONCHO_TOOL_SCHEMA = {
|
|||
|
||||
# ── Tool handler ──
|
||||
|
||||
|
||||
def _handle_query_user_context(args: dict, **kw) -> str:
|
||||
"""Execute the Honcho context query."""
|
||||
query = args.get("query", "")
|
||||
|
|
@ -84,6 +85,7 @@ def _handle_query_user_context(args: dict, **kw) -> str:
|
|||
|
||||
# ── Availability check ──
|
||||
|
||||
|
||||
def _check_honcho_available() -> bool:
|
||||
"""Tool is only available when Honcho is active."""
|
||||
return _session_manager is not None and _session_key is not None
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
"""
|
||||
Image Generation Tools Module
|
||||
|
||||
This module provides image generation tools using FAL.ai's FLUX 2 Pro model with
|
||||
This module provides image generation tools using FAL.ai's FLUX 2 Pro model with
|
||||
automatic upscaling via FAL.ai's Clarity Upscaler for enhanced image quality.
|
||||
|
||||
Available tools:
|
||||
|
|
@ -19,7 +19,7 @@ Features:
|
|||
Usage:
|
||||
from image_generation_tool import image_generate_tool
|
||||
import asyncio
|
||||
|
||||
|
||||
# Generate and automatically upscale an image
|
||||
result = await image_generate_tool(
|
||||
prompt="A serene mountain landscape with cherry blossoms",
|
||||
|
|
@ -28,12 +28,14 @@ Usage:
|
|||
)
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import datetime
|
||||
from typing import Dict, Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import fal_client
|
||||
|
||||
from tools.debug_helpers import DebugSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -51,11 +53,7 @@ ENABLE_SAFETY_CHECKER = False
|
|||
SAFETY_TOLERANCE = "5" # Maximum tolerance (1-5, where 5 is most permissive)
|
||||
|
||||
# Aspect ratio mapping - simplified choices for model to select
|
||||
ASPECT_RATIO_MAP = {
|
||||
"landscape": "landscape_16_9",
|
||||
"square": "square_hd",
|
||||
"portrait": "portrait_16_9"
|
||||
}
|
||||
ASPECT_RATIO_MAP = {"landscape": "landscape_16_9", "square": "square_hd", "portrait": "portrait_16_9"}
|
||||
VALID_ASPECT_RATIOS = list(ASPECT_RATIO_MAP.keys())
|
||||
|
||||
# Configuration for automatic upscaling
|
||||
|
|
@ -70,9 +68,7 @@ UPSCALER_GUIDANCE_SCALE = 4
|
|||
UPSCALER_NUM_INFERENCE_STEPS = 18
|
||||
|
||||
# Valid parameter values for validation based on FLUX 2 Pro documentation
|
||||
VALID_IMAGE_SIZES = [
|
||||
"square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"
|
||||
]
|
||||
VALID_IMAGE_SIZES = ["square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"]
|
||||
VALID_OUTPUT_FORMATS = ["jpeg", "png"]
|
||||
VALID_ACCELERATION_MODES = ["none", "regular", "high"]
|
||||
|
||||
|
|
@ -80,16 +76,16 @@ _debug = DebugSession("image_tools", env_var="IMAGE_TOOLS_DEBUG")
|
|||
|
||||
|
||||
def _validate_parameters(
|
||||
image_size: Union[str, Dict[str, int]],
|
||||
image_size: str | dict[str, int],
|
||||
num_inference_steps: int,
|
||||
guidance_scale: float,
|
||||
num_images: int,
|
||||
output_format: str,
|
||||
acceleration: str = "none"
|
||||
) -> Dict[str, Any]:
|
||||
acceleration: str = "none",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Validate and normalize image generation parameters for FLUX 2 Pro model.
|
||||
|
||||
|
||||
Args:
|
||||
image_size: Either a preset string or custom size dict
|
||||
num_inference_steps: Number of inference steps
|
||||
|
|
@ -97,15 +93,15 @@ def _validate_parameters(
|
|||
num_images: Number of images to generate
|
||||
output_format: Output format for images
|
||||
acceleration: Acceleration mode for generation speed
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Validated and normalized parameters
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If any parameter is invalid
|
||||
"""
|
||||
validated = {}
|
||||
|
||||
|
||||
# Validate image_size
|
||||
if isinstance(image_size, str):
|
||||
if image_size not in VALID_IMAGE_SIZES:
|
||||
|
|
@ -123,52 +119,52 @@ def _validate_parameters(
|
|||
validated["image_size"] = image_size
|
||||
else:
|
||||
raise ValueError("image_size must be either a preset string or a dict with width/height")
|
||||
|
||||
|
||||
# Validate num_inference_steps
|
||||
if not isinstance(num_inference_steps, int) or num_inference_steps < 1 or num_inference_steps > 100:
|
||||
raise ValueError("num_inference_steps must be an integer between 1 and 100")
|
||||
validated["num_inference_steps"] = num_inference_steps
|
||||
|
||||
|
||||
# Validate guidance_scale (FLUX 2 Pro default is 4.5)
|
||||
if not isinstance(guidance_scale, (int, float)) or guidance_scale < 0.1 or guidance_scale > 20.0:
|
||||
raise ValueError("guidance_scale must be a number between 0.1 and 20.0")
|
||||
validated["guidance_scale"] = float(guidance_scale)
|
||||
|
||||
|
||||
# Validate num_images
|
||||
if not isinstance(num_images, int) or num_images < 1 or num_images > 4:
|
||||
raise ValueError("num_images must be an integer between 1 and 4")
|
||||
validated["num_images"] = num_images
|
||||
|
||||
|
||||
# Validate output_format
|
||||
if output_format not in VALID_OUTPUT_FORMATS:
|
||||
raise ValueError(f"Invalid output_format '{output_format}'. Must be one of: {VALID_OUTPUT_FORMATS}")
|
||||
validated["output_format"] = output_format
|
||||
|
||||
|
||||
# Validate acceleration
|
||||
if acceleration not in VALID_ACCELERATION_MODES:
|
||||
raise ValueError(f"Invalid acceleration '{acceleration}'. Must be one of: {VALID_ACCELERATION_MODES}")
|
||||
validated["acceleration"] = acceleration
|
||||
|
||||
|
||||
return validated
|
||||
|
||||
|
||||
def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]:
|
||||
def _upscale_image(image_url: str, original_prompt: str) -> dict[str, Any]:
|
||||
"""
|
||||
Upscale an image using FAL.ai's Clarity Upscaler.
|
||||
|
||||
|
||||
Uses the synchronous fal_client API to avoid event loop lifecycle issues
|
||||
when called from threaded contexts (e.g. gateway thread pool).
|
||||
|
||||
|
||||
Args:
|
||||
image_url (str): URL of the image to upscale
|
||||
original_prompt (str): Original prompt used to generate the image
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Upscaled image data or None if upscaling fails
|
||||
"""
|
||||
try:
|
||||
logger.info("Upscaling image with Clarity Upscaler...")
|
||||
|
||||
|
||||
# Prepare arguments for upscaler
|
||||
upscaler_arguments = {
|
||||
"image_url": image_url,
|
||||
|
|
@ -179,35 +175,36 @@ def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]:
|
|||
"resemblance": UPSCALER_RESEMBLANCE,
|
||||
"guidance_scale": UPSCALER_GUIDANCE_SCALE,
|
||||
"num_inference_steps": UPSCALER_NUM_INFERENCE_STEPS,
|
||||
"enable_safety_checker": UPSCALER_SAFETY_CHECKER
|
||||
"enable_safety_checker": UPSCALER_SAFETY_CHECKER,
|
||||
}
|
||||
|
||||
|
||||
# Use sync API — fal_client.submit() uses httpx.Client (no event loop).
|
||||
# The async API (submit_async) caches a global httpx.AsyncClient via
|
||||
# @cached_property, which breaks when asyncio.run() destroys the loop
|
||||
# between calls (gateway thread-pool pattern).
|
||||
handler = fal_client.submit(
|
||||
UPSCALER_MODEL,
|
||||
arguments=upscaler_arguments
|
||||
)
|
||||
|
||||
handler = fal_client.submit(UPSCALER_MODEL, arguments=upscaler_arguments)
|
||||
|
||||
# Get the upscaled result (sync — blocks until done)
|
||||
result = handler.get()
|
||||
|
||||
|
||||
if result and "image" in result:
|
||||
upscaled_image = result["image"]
|
||||
logger.info("Image upscaled successfully to %sx%s", upscaled_image.get('width', 'unknown'), upscaled_image.get('height', 'unknown'))
|
||||
logger.info(
|
||||
"Image upscaled successfully to %sx%s",
|
||||
upscaled_image.get("width", "unknown"),
|
||||
upscaled_image.get("height", "unknown"),
|
||||
)
|
||||
return {
|
||||
"url": upscaled_image["url"],
|
||||
"width": upscaled_image.get("width", 0),
|
||||
"height": upscaled_image.get("height", 0),
|
||||
"upscaled": True,
|
||||
"upscale_factor": UPSCALER_FACTOR
|
||||
"upscale_factor": UPSCALER_FACTOR,
|
||||
}
|
||||
else:
|
||||
logger.error("Upscaler returned invalid response")
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error upscaling image: %s", e)
|
||||
return None
|
||||
|
|
@ -220,16 +217,16 @@ def image_generate_tool(
|
|||
guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
|
||||
num_images: int = DEFAULT_NUM_IMAGES,
|
||||
output_format: str = DEFAULT_OUTPUT_FORMAT,
|
||||
seed: Optional[int] = None
|
||||
seed: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate images from text prompts using FAL.ai's FLUX 2 Pro model with automatic upscaling.
|
||||
|
||||
|
||||
Uses the synchronous fal_client API to avoid event loop lifecycle issues.
|
||||
The async API's global httpx.AsyncClient (cached via @cached_property) breaks
|
||||
when asyncio.run() destroys and recreates event loops between calls, which
|
||||
happens in the gateway's thread-pool pattern.
|
||||
|
||||
|
||||
Args:
|
||||
prompt (str): The text prompt describing the desired image
|
||||
aspect_ratio (str): Image aspect ratio - "landscape", "square", or "portrait" (default: "landscape")
|
||||
|
|
@ -238,7 +235,7 @@ def image_generate_tool(
|
|||
num_images (int): Number of images to generate (1-4, default: 1)
|
||||
output_format (str): Image format "jpeg" or "png" (default: "png")
|
||||
seed (Optional[int]): Random seed for reproducible results (optional)
|
||||
|
||||
|
||||
Returns:
|
||||
str: JSON string containing minimal generation results:
|
||||
{
|
||||
|
|
@ -252,7 +249,7 @@ def image_generate_tool(
|
|||
logger.warning("Invalid aspect_ratio '%s', defaulting to '%s'", aspect_ratio, DEFAULT_ASPECT_RATIO)
|
||||
aspect_ratio_lower = DEFAULT_ASPECT_RATIO
|
||||
image_size = ASPECT_RATIO_MAP[aspect_ratio_lower]
|
||||
|
||||
|
||||
debug_call_data = {
|
||||
"parameters": {
|
||||
"prompt": prompt,
|
||||
|
|
@ -262,32 +259,32 @@ def image_generate_tool(
|
|||
"guidance_scale": guidance_scale,
|
||||
"num_images": num_images,
|
||||
"output_format": output_format,
|
||||
"seed": seed
|
||||
"seed": seed,
|
||||
},
|
||||
"error": None,
|
||||
"success": False,
|
||||
"images_generated": 0,
|
||||
"generation_time": 0
|
||||
"generation_time": 0,
|
||||
}
|
||||
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
|
||||
try:
|
||||
logger.info("Generating %s image(s) with FLUX 2 Pro: %s", num_images, prompt[:80])
|
||||
|
||||
|
||||
# Validate prompt
|
||||
if not prompt or not isinstance(prompt, str) or len(prompt.strip()) == 0:
|
||||
raise ValueError("Prompt is required and must be a non-empty string")
|
||||
|
||||
|
||||
# Check API key availability
|
||||
if not os.getenv("FAL_KEY"):
|
||||
raise ValueError("FAL_KEY environment variable not set")
|
||||
|
||||
|
||||
# Validate other parameters
|
||||
validated_params = _validate_parameters(
|
||||
image_size, num_inference_steps, guidance_scale, num_images, output_format, "none"
|
||||
)
|
||||
|
||||
|
||||
# Prepare arguments for FAL.ai FLUX 2 Pro API
|
||||
arguments = {
|
||||
"prompt": prompt.strip(),
|
||||
|
|
@ -298,51 +295,44 @@ def image_generate_tool(
|
|||
"output_format": validated_params["output_format"],
|
||||
"enable_safety_checker": ENABLE_SAFETY_CHECKER,
|
||||
"safety_tolerance": SAFETY_TOLERANCE,
|
||||
"sync_mode": True # Use sync mode for immediate results
|
||||
"sync_mode": True, # Use sync mode for immediate results
|
||||
}
|
||||
|
||||
|
||||
# Add seed if provided
|
||||
if seed is not None and isinstance(seed, int):
|
||||
arguments["seed"] = seed
|
||||
|
||||
|
||||
logger.info("Submitting generation request to FAL.ai FLUX 2 Pro...")
|
||||
logger.info(" Model: %s", DEFAULT_MODEL)
|
||||
logger.info(" Aspect Ratio: %s -> %s", aspect_ratio_lower, image_size)
|
||||
logger.info(" Steps: %s", validated_params['num_inference_steps'])
|
||||
logger.info(" Guidance: %s", validated_params['guidance_scale'])
|
||||
|
||||
logger.info(" Steps: %s", validated_params["num_inference_steps"])
|
||||
logger.info(" Guidance: %s", validated_params["guidance_scale"])
|
||||
|
||||
# Submit request to FAL.ai using sync API (avoids cached event loop issues)
|
||||
handler = fal_client.submit(
|
||||
DEFAULT_MODEL,
|
||||
arguments=arguments
|
||||
)
|
||||
|
||||
handler = fal_client.submit(DEFAULT_MODEL, arguments=arguments)
|
||||
|
||||
# Get the result (sync — blocks until done)
|
||||
result = handler.get()
|
||||
|
||||
|
||||
generation_time = (datetime.datetime.now() - start_time).total_seconds()
|
||||
|
||||
|
||||
# Process the response
|
||||
if not result or "images" not in result:
|
||||
raise ValueError("Invalid response from FAL.ai API - no images returned")
|
||||
|
||||
|
||||
images = result.get("images", [])
|
||||
if not images:
|
||||
raise ValueError("No images were generated")
|
||||
|
||||
|
||||
# Format image data and upscale images
|
||||
formatted_images = []
|
||||
for img in images:
|
||||
if isinstance(img, dict) and "url" in img:
|
||||
original_image = {
|
||||
"url": img["url"],
|
||||
"width": img.get("width", 0),
|
||||
"height": img.get("height", 0)
|
||||
}
|
||||
|
||||
original_image = {"url": img["url"], "width": img.get("width", 0), "height": img.get("height", 0)}
|
||||
|
||||
# Attempt to upscale the image
|
||||
upscaled_image = _upscale_image(img["url"], prompt.strip())
|
||||
|
||||
|
||||
if upscaled_image:
|
||||
# Use upscaled image if successful
|
||||
formatted_images.append(upscaled_image)
|
||||
|
|
@ -351,52 +341,48 @@ def image_generate_tool(
|
|||
logger.warning("Using original image as fallback")
|
||||
original_image["upscaled"] = False
|
||||
formatted_images.append(original_image)
|
||||
|
||||
|
||||
if not formatted_images:
|
||||
raise ValueError("No valid image URLs returned from API")
|
||||
|
||||
|
||||
upscaled_count = sum(1 for img in formatted_images if img.get("upscaled", False))
|
||||
logger.info("Generated %s image(s) in %.1fs (%s upscaled)", len(formatted_images), generation_time, upscaled_count)
|
||||
|
||||
logger.info(
|
||||
"Generated %s image(s) in %.1fs (%s upscaled)", len(formatted_images), generation_time, upscaled_count
|
||||
)
|
||||
|
||||
# Prepare successful response - minimal format
|
||||
response_data = {
|
||||
"success": True,
|
||||
"image": formatted_images[0]["url"] if formatted_images else None
|
||||
}
|
||||
|
||||
response_data = {"success": True, "image": formatted_images[0]["url"] if formatted_images else None}
|
||||
|
||||
debug_call_data["success"] = True
|
||||
debug_call_data["images_generated"] = len(formatted_images)
|
||||
debug_call_data["generation_time"] = generation_time
|
||||
|
||||
|
||||
# Log debug information
|
||||
_debug.log_call("image_generate_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
|
||||
return json.dumps(response_data, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
generation_time = (datetime.datetime.now() - start_time).total_seconds()
|
||||
error_msg = f"Error generating image: {str(e)}"
|
||||
logger.error("%s", error_msg)
|
||||
|
||||
|
||||
# Prepare error response - minimal format
|
||||
response_data = {
|
||||
"success": False,
|
||||
"image": None
|
||||
}
|
||||
|
||||
response_data = {"success": False, "image": None}
|
||||
|
||||
debug_call_data["error"] = error_msg
|
||||
debug_call_data["generation_time"] = generation_time
|
||||
_debug.log_call("image_generate_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
|
||||
return json.dumps(response_data, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def check_fal_api_key() -> bool:
|
||||
"""
|
||||
Check if the FAL.ai API key is available in environment variables.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if API key is set, False otherwise
|
||||
"""
|
||||
|
|
@ -406,7 +392,7 @@ def check_fal_api_key() -> bool:
|
|||
def check_image_generation_requirements() -> bool:
|
||||
"""
|
||||
Check if all requirements for image generation tools are met.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if requirements are met, False otherwise
|
||||
"""
|
||||
|
|
@ -414,19 +400,20 @@ def check_image_generation_requirements() -> bool:
|
|||
# Check API key
|
||||
if not check_fal_api_key():
|
||||
return False
|
||||
|
||||
|
||||
# Check if fal_client is available
|
||||
import fal_client
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def get_debug_session_info() -> Dict[str, Any]:
|
||||
def get_debug_session_info() -> dict[str, Any]:
|
||||
"""
|
||||
Get information about the current debug session.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing debug session information
|
||||
"""
|
||||
|
|
@ -439,10 +426,10 @@ if __name__ == "__main__":
|
|||
"""
|
||||
print("🎨 Image Generation Tools Module - FLUX 2 Pro + Auto Upscaling")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# Check if API key is available
|
||||
api_available = check_fal_api_key()
|
||||
|
||||
|
||||
if not api_available:
|
||||
print("❌ FAL_KEY environment variable not set")
|
||||
print("Please set your API key: export FAL_KEY='your-key-here'")
|
||||
|
|
@ -450,27 +437,28 @@ if __name__ == "__main__":
|
|||
exit(1)
|
||||
else:
|
||||
print("✅ FAL.ai API key found")
|
||||
|
||||
|
||||
# Check if fal_client is available
|
||||
try:
|
||||
import fal_client
|
||||
|
||||
print("✅ fal_client library available")
|
||||
except ImportError:
|
||||
print("❌ fal_client library not found")
|
||||
print("Please install: pip install fal-client")
|
||||
exit(1)
|
||||
|
||||
|
||||
print("🛠️ Image generation tools ready for use!")
|
||||
print(f"🤖 Using model: {DEFAULT_MODEL}")
|
||||
print(f"🔍 Auto-upscaling with: {UPSCALER_MODEL} ({UPSCALER_FACTOR}x)")
|
||||
|
||||
|
||||
# Show debug mode status
|
||||
if _debug.active:
|
||||
print(f"🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
|
||||
print(f" Debug logs will be saved to: ./logs/image_tools_debug_{_debug.session_id}.json")
|
||||
else:
|
||||
print("🐛 Debug mode disabled (set IMAGE_TOOLS_DEBUG=true to enable)")
|
||||
|
||||
|
||||
print("\nBasic usage:")
|
||||
print(" from image_generation_tool import image_generate_tool")
|
||||
print(" import asyncio")
|
||||
|
|
@ -484,23 +472,23 @@ if __name__ == "__main__":
|
|||
print(" )")
|
||||
print(" print(result)")
|
||||
print(" asyncio.run(main())")
|
||||
|
||||
|
||||
print("\nSupported image sizes:")
|
||||
for size in VALID_IMAGE_SIZES:
|
||||
print(f" - {size}")
|
||||
print(" - Custom: {'width': 512, 'height': 768} (if needed)")
|
||||
|
||||
|
||||
print("\nAcceleration modes:")
|
||||
for mode in VALID_ACCELERATION_MODES:
|
||||
print(f" - {mode}")
|
||||
|
||||
|
||||
print("\nExample prompts:")
|
||||
print(" - 'A candid street photo of a woman with a pink bob and bold eyeliner'")
|
||||
print(" - 'Modern architecture building with glass facade, sunset lighting'")
|
||||
print(" - 'Abstract art with vibrant colors and geometric patterns'")
|
||||
print(" - 'Portrait of a wise old owl perched on ancient tree branch'")
|
||||
print(" - 'Futuristic cityscape with flying cars and neon lights'")
|
||||
|
||||
|
||||
print("\nDebug mode:")
|
||||
print(" # Enable debug logging")
|
||||
print(" export IMAGE_TOOLS_DEBUG=true")
|
||||
|
|
@ -521,17 +509,17 @@ IMAGE_GENERATE_SCHEMA = {
|
|||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "The text prompt describing the desired image. Be detailed and descriptive."
|
||||
"description": "The text prompt describing the desired image. Be detailed and descriptive.",
|
||||
},
|
||||
"aspect_ratio": {
|
||||
"type": "string",
|
||||
"enum": ["landscape", "square", "portrait"],
|
||||
"description": "The aspect ratio of the generated image. 'landscape' is 16:9 wide, 'portrait' is 16:9 tall, 'square' is 1:1.",
|
||||
"default": "landscape"
|
||||
}
|
||||
"default": "landscape",
|
||||
},
|
||||
},
|
||||
"required": ["prompt"]
|
||||
}
|
||||
"required": ["prompt"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ import os
|
|||
import re
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -91,9 +91,11 @@ _MCP_SAMPLING_TYPES = False
|
|||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
_MCP_AVAILABLE = True
|
||||
try:
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
_MCP_HTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
_MCP_HTTP_AVAILABLE = False
|
||||
|
|
@ -108,6 +110,7 @@ try:
|
|||
TextContent,
|
||||
ToolUseContent,
|
||||
)
|
||||
|
||||
_MCP_SAMPLING_TYPES = True
|
||||
except ImportError:
|
||||
logger.debug("MCP sampling types not available -- sampling disabled")
|
||||
|
|
@ -118,27 +121,36 @@ except ImportError:
|
|||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DEFAULT_TOOL_TIMEOUT = 120 # seconds for tool calls
|
||||
_DEFAULT_CONNECT_TIMEOUT = 60 # seconds for initial connection per server
|
||||
_DEFAULT_TOOL_TIMEOUT = 120 # seconds for tool calls
|
||||
_DEFAULT_CONNECT_TIMEOUT = 60 # seconds for initial connection per server
|
||||
_MAX_RECONNECT_RETRIES = 5
|
||||
_MAX_BACKOFF_SECONDS = 60
|
||||
|
||||
# Environment variables that are safe to pass to stdio subprocesses
|
||||
_SAFE_ENV_KEYS = frozenset({
|
||||
"PATH", "HOME", "USER", "LANG", "LC_ALL", "TERM", "SHELL", "TMPDIR",
|
||||
})
|
||||
_SAFE_ENV_KEYS = frozenset(
|
||||
{
|
||||
"PATH",
|
||||
"HOME",
|
||||
"USER",
|
||||
"LANG",
|
||||
"LC_ALL",
|
||||
"TERM",
|
||||
"SHELL",
|
||||
"TMPDIR",
|
||||
}
|
||||
)
|
||||
|
||||
# Regex for credential patterns to strip from error messages
|
||||
_CREDENTIAL_PATTERN = re.compile(
|
||||
r"(?:"
|
||||
r"ghp_[A-Za-z0-9_]{1,255}" # GitHub PAT
|
||||
r"|sk-[A-Za-z0-9_]{1,255}" # OpenAI-style key
|
||||
r"|Bearer\s+\S+" # Bearer token
|
||||
r"|token=[^\s&,;\"']{1,255}" # token=...
|
||||
r"|key=[^\s&,;\"']{1,255}" # key=...
|
||||
r"|API_KEY=[^\s&,;\"']{1,255}" # API_KEY=...
|
||||
r"|password=[^\s&,;\"']{1,255}" # password=...
|
||||
r"|secret=[^\s&,;\"']{1,255}" # secret=...
|
||||
r"ghp_[A-Za-z0-9_]{1,255}" # GitHub PAT
|
||||
r"|sk-[A-Za-z0-9_]{1,255}" # OpenAI-style key
|
||||
r"|Bearer\s+\S+" # Bearer token
|
||||
r"|token=[^\s&,;\"']{1,255}" # token=...
|
||||
r"|key=[^\s&,;\"']{1,255}" # key=...
|
||||
r"|API_KEY=[^\s&,;\"']{1,255}" # API_KEY=...
|
||||
r"|password=[^\s&,;\"']{1,255}" # password=...
|
||||
r"|secret=[^\s&,;\"']{1,255}" # secret=...
|
||||
r")",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
|
@ -148,7 +160,8 @@ _CREDENTIAL_PATTERN = re.compile(
|
|||
# Security helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _build_safe_env(user_env: Optional[dict]) -> dict:
|
||||
|
||||
def _build_safe_env(user_env: dict | None) -> dict:
|
||||
"""Build a filtered environment dict for stdio subprocesses.
|
||||
|
||||
Only passes through safe baseline variables (PATH, HOME, etc.) and XDG_*
|
||||
|
|
@ -180,6 +193,7 @@ def _sanitize_error(text: str) -> str:
|
|||
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _safe_numeric(value, default, coerce=int, minimum=1):
|
||||
"""Coerce a config value to a numeric type, returning *default* on failure.
|
||||
|
||||
|
|
@ -216,18 +230,22 @@ class SamplingHandler:
|
|||
self.timeout = _safe_numeric(config.get("timeout", 30), 30, float)
|
||||
self.max_tokens_cap = _safe_numeric(config.get("max_tokens_cap", 4096), 4096, int)
|
||||
self.max_tool_rounds = _safe_numeric(
|
||||
config.get("max_tool_rounds", 5), 5, int, minimum=0,
|
||||
config.get("max_tool_rounds", 5),
|
||||
5,
|
||||
int,
|
||||
minimum=0,
|
||||
)
|
||||
self.model_override = config.get("model")
|
||||
self.allowed_models = config.get("allowed_models", [])
|
||||
|
||||
_log_levels = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING}
|
||||
self.audit_level = _log_levels.get(
|
||||
str(config.get("log_level", "info")).lower(), logging.INFO,
|
||||
str(config.get("log_level", "info")).lower(),
|
||||
logging.INFO,
|
||||
)
|
||||
|
||||
# Per-instance state
|
||||
self._rate_timestamps: List[float] = []
|
||||
self._rate_timestamps: list[float] = []
|
||||
self._tool_loop_count = 0
|
||||
self.metrics = {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}
|
||||
|
||||
|
|
@ -245,7 +263,7 @@ class SamplingHandler:
|
|||
|
||||
# -- Model resolution ----------------------------------------------------
|
||||
|
||||
def _resolve_model(self, preferences) -> Optional[str]:
|
||||
def _resolve_model(self, preferences) -> str | None:
|
||||
"""Config override > server hint > None (use default)."""
|
||||
if self.model_override:
|
||||
return self.model_override
|
||||
|
|
@ -265,7 +283,7 @@ class SamplingHandler:
|
|||
items = block.content if isinstance(block.content, list) else [block.content]
|
||||
return "\n".join(item.text for item in items if hasattr(item, "text"))
|
||||
|
||||
def _convert_messages(self, params) -> List[dict]:
|
||||
def _convert_messages(self, params) -> list[dict]:
|
||||
"""Convert MCP SamplingMessages to OpenAI format.
|
||||
|
||||
Uses ``msg.content_as_list`` (SDK helper) so single-block and
|
||||
|
|
@ -273,37 +291,47 @@ class SamplingHandler:
|
|||
with ``isinstance`` on real SDK types when available, falling back
|
||||
to duck-typing via ``hasattr`` for compatibility.
|
||||
"""
|
||||
messages: List[dict] = []
|
||||
messages: list[dict] = []
|
||||
for msg in params.messages:
|
||||
blocks = msg.content_as_list if hasattr(msg, "content_as_list") else (
|
||||
msg.content if isinstance(msg.content, list) else [msg.content]
|
||||
blocks = (
|
||||
msg.content_as_list
|
||||
if hasattr(msg, "content_as_list")
|
||||
else (msg.content if isinstance(msg.content, list) else [msg.content])
|
||||
)
|
||||
|
||||
# Separate blocks by kind
|
||||
tool_results = [b for b in blocks if hasattr(b, "toolUseId")]
|
||||
tool_uses = [b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")]
|
||||
content_blocks = [b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))]
|
||||
tool_uses = [
|
||||
b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")
|
||||
]
|
||||
content_blocks = [
|
||||
b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))
|
||||
]
|
||||
|
||||
# Emit tool result messages (role: tool)
|
||||
for tr in tool_results:
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.toolUseId,
|
||||
"content": self._extract_tool_result_text(tr),
|
||||
})
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.toolUseId,
|
||||
"content": self._extract_tool_result_text(tr),
|
||||
}
|
||||
)
|
||||
|
||||
# Emit assistant tool_calls message
|
||||
if tool_uses:
|
||||
tc_list = []
|
||||
for tu in tool_uses:
|
||||
tc_list.append({
|
||||
"id": getattr(tu, "id", f"call_{len(tc_list)}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tu.name,
|
||||
"arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
|
||||
},
|
||||
})
|
||||
tc_list.append(
|
||||
{
|
||||
"id": getattr(tu, "id", f"call_{len(tc_list)}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tu.name,
|
||||
"arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
|
||||
},
|
||||
}
|
||||
)
|
||||
msg_dict: dict = {"role": msg.role, "tool_calls": tc_list}
|
||||
# Include any accompanying text
|
||||
text_parts = [b.text for b in content_blocks if hasattr(b, "text")]
|
||||
|
|
@ -320,10 +348,12 @@ class SamplingHandler:
|
|||
if hasattr(block, "text"):
|
||||
parts.append({"type": "text", "text": block.text})
|
||||
elif hasattr(block, "data") and hasattr(block, "mimeType"):
|
||||
parts.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
|
||||
})
|
||||
parts.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Unsupported sampling content block type: %s (skipped)",
|
||||
|
|
@ -352,16 +382,13 @@ class SamplingHandler:
|
|||
# Tool loop governance
|
||||
if self.max_tool_rounds == 0:
|
||||
self._tool_loop_count = 0
|
||||
return self._error(
|
||||
f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)"
|
||||
)
|
||||
return self._error(f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)")
|
||||
|
||||
self._tool_loop_count += 1
|
||||
if self._tool_loop_count > self.max_tool_rounds:
|
||||
self._tool_loop_count = 0
|
||||
return self._error(
|
||||
f"Tool loop limit exceeded for server '{self.server_name}' "
|
||||
f"(max {self.max_tool_rounds} rounds)"
|
||||
f"Tool loop limit exceeded for server '{self.server_name}' (max {self.max_tool_rounds} rounds)"
|
||||
)
|
||||
|
||||
content_blocks = []
|
||||
|
|
@ -372,25 +399,28 @@ class SamplingHandler:
|
|||
parsed = json.loads(args)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.warning(
|
||||
"MCP server '%s': malformed tool_calls arguments "
|
||||
"from LLM (wrapping as raw): %.100s",
|
||||
self.server_name, args,
|
||||
"MCP server '%s': malformed tool_calls arguments from LLM (wrapping as raw): %.100s",
|
||||
self.server_name,
|
||||
args,
|
||||
)
|
||||
parsed = {"_raw": args}
|
||||
else:
|
||||
parsed = args if isinstance(args, dict) else {"_raw": str(args)}
|
||||
|
||||
content_blocks.append(ToolUseContent(
|
||||
type="tool_use",
|
||||
id=tc.id,
|
||||
name=tc.function.name,
|
||||
input=parsed,
|
||||
))
|
||||
content_blocks.append(
|
||||
ToolUseContent(
|
||||
type="tool_use",
|
||||
id=tc.id,
|
||||
name=tc.function.name,
|
||||
input=parsed,
|
||||
)
|
||||
)
|
||||
|
||||
logger.log(
|
||||
self.audit_level,
|
||||
"MCP server '%s' sampling response: model=%s, tokens=%s, tool_calls=%d",
|
||||
self.server_name, response.model,
|
||||
self.server_name,
|
||||
response.model,
|
||||
getattr(getattr(response, "usage", None), "total_tokens", "?"),
|
||||
len(content_blocks),
|
||||
)
|
||||
|
|
@ -410,7 +440,8 @@ class SamplingHandler:
|
|||
logger.log(
|
||||
self.audit_level,
|
||||
"MCP server '%s' sampling response: model=%s, tokens=%s",
|
||||
self.server_name, response.model,
|
||||
self.server_name,
|
||||
response.model,
|
||||
getattr(getattr(response, "usage", None), "total_tokens", "?"),
|
||||
)
|
||||
|
||||
|
|
@ -445,12 +476,12 @@ class SamplingHandler:
|
|||
if not self._check_rate_limit():
|
||||
logger.warning(
|
||||
"MCP server '%s' sampling rate limit exceeded (%d/min)",
|
||||
self.server_name, self.max_rpm,
|
||||
self.server_name,
|
||||
self.max_rpm,
|
||||
)
|
||||
self.metrics["errors"] += 1
|
||||
return self._error(
|
||||
f"Sampling rate limit exceeded for server '{self.server_name}' "
|
||||
f"({self.max_rpm} requests/minute)"
|
||||
f"Sampling rate limit exceeded for server '{self.server_name}' ({self.max_rpm} requests/minute)"
|
||||
)
|
||||
|
||||
# Resolve model
|
||||
|
|
@ -458,6 +489,7 @@ class SamplingHandler:
|
|||
|
||||
# Get auxiliary LLM client
|
||||
from agent.auxiliary_client import get_text_auxiliary_client
|
||||
|
||||
client, default_model = get_text_auxiliary_client()
|
||||
if client is None:
|
||||
self.metrics["errors"] += 1
|
||||
|
|
@ -469,7 +501,8 @@ class SamplingHandler:
|
|||
if self.allowed_models and resolved_model not in self.allowed_models:
|
||||
logger.warning(
|
||||
"MCP server '%s' requested model '%s' not in allowed_models",
|
||||
self.server_name, resolved_model,
|
||||
self.server_name,
|
||||
resolved_model,
|
||||
)
|
||||
self.metrics["errors"] += 1
|
||||
return self._error(
|
||||
|
|
@ -515,7 +548,10 @@ class SamplingHandler:
|
|||
logger.log(
|
||||
self.audit_level,
|
||||
"MCP server '%s' sampling request: model=%s, max_tokens=%d, messages=%d",
|
||||
self.server_name, resolved_model, max_tokens, len(messages),
|
||||
self.server_name,
|
||||
resolved_model,
|
||||
max_tokens,
|
||||
len(messages),
|
||||
)
|
||||
|
||||
# Offload sync LLM call to thread (non-blocking)
|
||||
|
|
@ -524,19 +560,15 @@ class SamplingHandler:
|
|||
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
asyncio.to_thread(_sync_call), timeout=self.timeout,
|
||||
asyncio.to_thread(_sync_call),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
self.metrics["errors"] += 1
|
||||
return self._error(
|
||||
f"Sampling LLM call timed out after {self.timeout}s "
|
||||
f"for server '{self.server_name}'"
|
||||
)
|
||||
return self._error(f"Sampling LLM call timed out after {self.timeout}s for server '{self.server_name}'")
|
||||
except Exception as exc:
|
||||
self.metrics["errors"] += 1
|
||||
return self._error(
|
||||
f"Sampling LLM call failed: {_sanitize_error(str(exc))}"
|
||||
)
|
||||
return self._error(f"Sampling LLM call failed: {_sanitize_error(str(exc))}")
|
||||
|
||||
# Track metrics
|
||||
choice = response.choices[0]
|
||||
|
|
@ -546,11 +578,7 @@ class SamplingHandler:
|
|||
self.metrics["tokens_used"] += total_tokens
|
||||
|
||||
# Dispatch based on response type
|
||||
if (
|
||||
choice.finish_reason == "tool_calls"
|
||||
and hasattr(choice.message, "tool_calls")
|
||||
and choice.message.tool_calls
|
||||
):
|
||||
if choice.finish_reason == "tool_calls" and hasattr(choice.message, "tool_calls") and choice.message.tool_calls:
|
||||
return self._build_tool_use_result(choice, response)
|
||||
|
||||
return self._build_text_result(choice, response)
|
||||
|
|
@ -560,6 +588,7 @@ class SamplingHandler:
|
|||
# Server task -- each MCP server lives in one long-lived asyncio Task
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MCPServerTask:
|
||||
"""Manages a single MCP server connection in a dedicated asyncio Task.
|
||||
|
||||
|
|
@ -571,22 +600,29 @@ class MCPServerTask:
|
|||
"""
|
||||
|
||||
__slots__ = (
|
||||
"name", "session", "tool_timeout",
|
||||
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
|
||||
"name",
|
||||
"session",
|
||||
"tool_timeout",
|
||||
"_task",
|
||||
"_ready",
|
||||
"_shutdown_event",
|
||||
"_tools",
|
||||
"_error",
|
||||
"_config",
|
||||
"_sampling",
|
||||
)
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.session: Optional[Any] = None
|
||||
self.session: Any | None = None
|
||||
self.tool_timeout: float = _DEFAULT_TOOL_TIMEOUT
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._task: asyncio.Task | None = None
|
||||
self._ready = asyncio.Event()
|
||||
self._shutdown_event = asyncio.Event()
|
||||
self._tools: list = []
|
||||
self._error: Optional[Exception] = None
|
||||
self._error: Exception | None = None
|
||||
self._config: dict = {}
|
||||
self._sampling: Optional[SamplingHandler] = None
|
||||
self._sampling: SamplingHandler | None = None
|
||||
|
||||
def _is_http(self) -> bool:
|
||||
"""Check if this server uses HTTP transport."""
|
||||
|
|
@ -599,9 +635,7 @@ class MCPServerTask:
|
|||
user_env = config.get("env")
|
||||
|
||||
if not command:
|
||||
raise ValueError(
|
||||
f"MCP server '{self.name}' has no 'command' in config"
|
||||
)
|
||||
raise ValueError(f"MCP server '{self.name}' has no 'command' in config")
|
||||
|
||||
safe_env = _build_safe_env(user_env)
|
||||
server_params = StdioServerParameters(
|
||||
|
|
@ -650,11 +684,7 @@ class MCPServerTask:
|
|||
if self.session is None:
|
||||
return
|
||||
tools_result = await self.session.list_tools()
|
||||
self._tools = (
|
||||
tools_result.tools
|
||||
if hasattr(tools_result, "tools")
|
||||
else []
|
||||
)
|
||||
self._tools = tools_result.tools if hasattr(tools_result, "tools") else []
|
||||
|
||||
async def run(self, config: dict):
|
||||
"""Long-lived coroutine: connect, discover tools, wait, disconnect.
|
||||
|
|
@ -704,24 +734,28 @@ class MCPServerTask:
|
|||
if self._shutdown_event.is_set():
|
||||
logger.debug(
|
||||
"MCP server '%s' disconnected during shutdown: %s",
|
||||
self.name, exc,
|
||||
self.name,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
retries += 1
|
||||
if retries > _MAX_RECONNECT_RETRIES:
|
||||
logger.warning(
|
||||
"MCP server '%s' failed after %d reconnection attempts, "
|
||||
"giving up: %s",
|
||||
self.name, _MAX_RECONNECT_RETRIES, exc,
|
||||
"MCP server '%s' failed after %d reconnection attempts, giving up: %s",
|
||||
self.name,
|
||||
_MAX_RECONNECT_RETRIES,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
"MCP server '%s' connection lost (attempt %d/%d), "
|
||||
"reconnecting in %.0fs: %s",
|
||||
self.name, retries, _MAX_RECONNECT_RETRIES,
|
||||
backoff, exc,
|
||||
"MCP server '%s' connection lost (attempt %d/%d), reconnecting in %.0fs: %s",
|
||||
self.name,
|
||||
retries,
|
||||
_MAX_RECONNECT_RETRIES,
|
||||
backoff,
|
||||
exc,
|
||||
)
|
||||
await asyncio.sleep(backoff)
|
||||
backoff = min(backoff * 2, _MAX_BACKOFF_SECONDS)
|
||||
|
|
@ -745,7 +779,7 @@ class MCPServerTask:
|
|||
if self._task and not self._task.done():
|
||||
try:
|
||||
await asyncio.wait_for(self._task, timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"MCP server '%s' shutdown timed out, cancelling task",
|
||||
self.name,
|
||||
|
|
@ -762,11 +796,11 @@ class MCPServerTask:
|
|||
# Module-level state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_servers: Dict[str, MCPServerTask] = {}
|
||||
_servers: dict[str, MCPServerTask] = {}
|
||||
|
||||
# Dedicated event loop running in a background daemon thread.
|
||||
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
_mcp_thread: Optional[threading.Thread] = None
|
||||
_mcp_loop: asyncio.AbstractEventLoop | None = None
|
||||
_mcp_thread: threading.Thread | None = None
|
||||
|
||||
# Protects _mcp_loop, _mcp_thread, and _servers from concurrent access.
|
||||
_lock = threading.Lock()
|
||||
|
|
@ -801,7 +835,8 @@ def _run_on_mcp_loop(coro, timeout: float = 30):
|
|||
# Config loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_mcp_config() -> Dict[str, dict]:
|
||||
|
||||
def _load_mcp_config() -> dict[str, dict]:
|
||||
"""Read ``mcp_servers`` from the Hermes config file.
|
||||
|
||||
Returns a dict of ``{server_name: server_config}`` or empty dict.
|
||||
|
|
@ -811,6 +846,7 @@ def _load_mcp_config() -> Dict[str, dict]:
|
|||
"""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
servers = config.get("mcp_servers")
|
||||
if not servers or not isinstance(servers, dict):
|
||||
|
|
@ -825,6 +861,7 @@ def _load_mcp_config() -> Dict[str, dict]:
|
|||
# Server connection helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _connect_server(name: str, config: dict) -> MCPServerTask:
|
||||
"""Create an MCPServerTask, start it, and return when ready.
|
||||
|
||||
|
|
@ -845,6 +882,7 @@ async def _connect_server(name: str, config: dict) -> MCPServerTask:
|
|||
# Handler / check-fn factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
|
||||
"""Return a sync handler that calls an MCP tool via the background loop.
|
||||
|
||||
|
|
@ -856,27 +894,21 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
|
|||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
|
||||
|
||||
async def _call():
|
||||
result = await server.session.call_tool(tool_name, arguments=args)
|
||||
# MCP CallToolResult has .content (list of content blocks) and .isError
|
||||
if result.isError:
|
||||
error_text = ""
|
||||
for block in (result.content or []):
|
||||
for block in result.content or []:
|
||||
if hasattr(block, "text"):
|
||||
error_text += block.text
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
error_text or "MCP tool returned an error"
|
||||
)
|
||||
})
|
||||
return json.dumps({"error": _sanitize_error(error_text or "MCP tool returned an error")})
|
||||
|
||||
# Collect text from content blocks
|
||||
parts: List[str] = []
|
||||
for block in (result.content or []):
|
||||
parts: list[str] = []
|
||||
for block in result.content or []:
|
||||
if hasattr(block, "text"):
|
||||
parts.append(block.text)
|
||||
return json.dumps({"result": "\n".join(parts) if parts else ""})
|
||||
|
|
@ -886,13 +918,11 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
|
|||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP tool %s/%s call failed: %s",
|
||||
server_name, tool_name, exc,
|
||||
server_name,
|
||||
tool_name,
|
||||
exc,
|
||||
)
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
|
||||
|
||||
return _handler
|
||||
|
||||
|
|
@ -904,14 +934,12 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
|
|||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
|
||||
|
||||
async def _call():
|
||||
result = await server.session.list_resources()
|
||||
resources = []
|
||||
for r in (result.resources if hasattr(result, "resources") else []):
|
||||
for r in result.resources if hasattr(result, "resources") else []:
|
||||
entry = {}
|
||||
if hasattr(r, "uri"):
|
||||
entry["uri"] = str(r.uri)
|
||||
|
|
@ -928,13 +956,11 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
|
|||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP %s/list_resources failed: %s", server_name, exc,
|
||||
"MCP %s/list_resources failed: %s",
|
||||
server_name,
|
||||
exc,
|
||||
)
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
|
||||
|
||||
return _handler
|
||||
|
||||
|
|
@ -946,9 +972,7 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
|
|||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
|
||||
|
||||
uri = args.get("uri")
|
||||
if not uri:
|
||||
|
|
@ -957,7 +981,7 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
|
|||
async def _call():
|
||||
result = await server.session.read_resource(uri)
|
||||
# read_resource returns ReadResourceResult with .contents list
|
||||
parts: List[str] = []
|
||||
parts: list[str] = []
|
||||
contents = result.contents if hasattr(result, "contents") else []
|
||||
for block in contents:
|
||||
if hasattr(block, "text"):
|
||||
|
|
@ -970,13 +994,11 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
|
|||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP %s/read_resource failed: %s", server_name, exc,
|
||||
"MCP %s/read_resource failed: %s",
|
||||
server_name,
|
||||
exc,
|
||||
)
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
|
||||
|
||||
return _handler
|
||||
|
||||
|
|
@ -988,14 +1010,12 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
|
|||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
|
||||
|
||||
async def _call():
|
||||
result = await server.session.list_prompts()
|
||||
prompts = []
|
||||
for p in (result.prompts if hasattr(result, "prompts") else []):
|
||||
for p in result.prompts if hasattr(result, "prompts") else []:
|
||||
entry = {}
|
||||
if hasattr(p, "name"):
|
||||
entry["name"] = p.name
|
||||
|
|
@ -1017,13 +1037,11 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
|
|||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP %s/list_prompts failed: %s", server_name, exc,
|
||||
"MCP %s/list_prompts failed: %s",
|
||||
server_name,
|
||||
exc,
|
||||
)
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
|
||||
|
||||
return _handler
|
||||
|
||||
|
|
@ -1035,9 +1053,7 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
|
|||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
|
||||
|
||||
name = args.get("name")
|
||||
if not name:
|
||||
|
|
@ -1048,7 +1064,7 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
|
|||
result = await server.session.get_prompt(name, arguments=arguments)
|
||||
# GetPromptResult has .messages list
|
||||
messages = []
|
||||
for msg in (result.messages if hasattr(result, "messages") else []):
|
||||
for msg in result.messages if hasattr(result, "messages") else []:
|
||||
entry = {}
|
||||
if hasattr(msg, "role"):
|
||||
entry["role"] = msg.role
|
||||
|
|
@ -1070,13 +1086,11 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
|
|||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP %s/get_prompt failed: %s", server_name, exc,
|
||||
"MCP %s/get_prompt failed: %s",
|
||||
server_name,
|
||||
exc,
|
||||
)
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
|
||||
|
||||
return _handler
|
||||
|
||||
|
|
@ -1096,6 +1110,7 @@ def _make_check_fn(server_name: str):
|
|||
# Discovery & registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
||||
"""Convert an MCP tool listing to the Hermes registry schema format.
|
||||
|
||||
|
|
@ -1114,14 +1129,16 @@ def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
|||
return {
|
||||
"name": prefixed_name,
|
||||
"description": mcp_tool.description or f"MCP tool {mcp_tool.name} from {server_name}",
|
||||
"parameters": mcp_tool.inputSchema if mcp_tool.inputSchema else {
|
||||
"parameters": mcp_tool.inputSchema
|
||||
if mcp_tool.inputSchema
|
||||
else {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _build_utility_schemas(server_name: str) -> List[dict]:
|
||||
def _build_utility_schemas(server_name: str) -> list[dict]:
|
||||
"""Build schemas for the MCP utility tools (resources & prompts).
|
||||
|
||||
Returns a list of (schema, handler_factory_name) tuples encoded as dicts
|
||||
|
|
@ -1192,9 +1209,9 @@ def _build_utility_schemas(server_name: str) -> List[dict]:
|
|||
]
|
||||
|
||||
|
||||
def _existing_tool_names() -> List[str]:
|
||||
def _existing_tool_names() -> list[str]:
|
||||
"""Return tool names for all currently connected servers."""
|
||||
names: List[str] = []
|
||||
names: list[str] = []
|
||||
for sname, server in _servers.items():
|
||||
for mcp_tool in server._tools:
|
||||
schema = _convert_mcp_schema(sname, mcp_tool)
|
||||
|
|
@ -1205,7 +1222,7 @@ def _existing_tool_names() -> List[str]:
|
|||
return names
|
||||
|
||||
|
||||
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
async def _discover_and_register_server(name: str, config: dict) -> list[str]:
|
||||
"""Connect to a single MCP server, discover tools, and register them.
|
||||
|
||||
Also registers utility tools for MCP Resources and Prompts support
|
||||
|
|
@ -1224,7 +1241,7 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
|||
with _lock:
|
||||
_servers[name] = server
|
||||
|
||||
registered_names: List[str] = []
|
||||
registered_names: list[str] = []
|
||||
toolset_name = f"mcp-{name}"
|
||||
|
||||
for mcp_tool in server._tools:
|
||||
|
|
@ -1277,7 +1294,9 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
|||
transport_type = "HTTP" if "url" in config else "stdio"
|
||||
logger.info(
|
||||
"MCP server '%s' (%s): registered %d tool(s): %s",
|
||||
name, transport_type, len(registered_names),
|
||||
name,
|
||||
transport_type,
|
||||
len(registered_names),
|
||||
", ".join(registered_names),
|
||||
)
|
||||
return registered_names
|
||||
|
|
@ -1287,7 +1306,8 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
|||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def discover_mcp_tools() -> List[str]:
|
||||
|
||||
def discover_mcp_tools() -> list[str]:
|
||||
"""Entry point: load config, connect to MCP servers, register tools.
|
||||
|
||||
Called from ``model_tools._discover_tools()``. Safe to call even when
|
||||
|
|
@ -1318,12 +1338,12 @@ def discover_mcp_tools() -> List[str]:
|
|||
# Start the background event loop for MCP connections
|
||||
_ensure_mcp_loop()
|
||||
|
||||
all_tools: List[str] = []
|
||||
all_tools: list[str] = []
|
||||
failed_count = 0
|
||||
|
||||
async def _discover_one(name: str, cfg: dict) -> List[str]:
|
||||
async def _discover_one(name: str, cfg: dict) -> list[str]:
|
||||
"""Connect to a single server and return its registered tool names."""
|
||||
transport_desc = cfg.get("url", f'{cfg.get("command", "?")} {" ".join(cfg.get("args", [])[:2])}')
|
||||
transport_desc = cfg.get("url", f"{cfg.get('command', '?')} {' '.join(cfg.get('args', [])[:2])}")
|
||||
try:
|
||||
registered = await _discover_and_register_server(name, cfg)
|
||||
transport_type = "HTTP" if "url" in cfg else "stdio"
|
||||
|
|
@ -1331,7 +1351,8 @@ def discover_mcp_tools() -> List[str]:
|
|||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to connect to MCP server '%s': %s",
|
||||
name, exc,
|
||||
name,
|
||||
exc,
|
||||
)
|
||||
return []
|
||||
|
||||
|
|
@ -1358,6 +1379,7 @@ def discover_mcp_tools() -> List[str]:
|
|||
if all_tools:
|
||||
# Dynamically inject into all hermes-* platform toolsets
|
||||
from toolsets import TOOLSETS
|
||||
|
||||
for ts_name, ts in TOOLSETS.items():
|
||||
if ts_name.startswith("hermes-"):
|
||||
for tool_name in all_tools:
|
||||
|
|
@ -1377,13 +1399,13 @@ def discover_mcp_tools() -> List[str]:
|
|||
return _existing_tool_names()
|
||||
|
||||
|
||||
def get_mcp_status() -> List[dict]:
|
||||
def get_mcp_status() -> list[dict]:
|
||||
"""Return status of all configured MCP servers for banner display.
|
||||
|
||||
Returns a list of dicts with keys: name, transport, tools, connected.
|
||||
Includes both successfully connected servers and configured-but-failed ones.
|
||||
"""
|
||||
result: List[dict] = []
|
||||
result: list[dict] = []
|
||||
|
||||
# Get configured servers from config
|
||||
configured = _load_mcp_config()
|
||||
|
|
@ -1407,12 +1429,14 @@ def get_mcp_status() -> List[dict]:
|
|||
entry["sampling"] = dict(server._sampling.metrics)
|
||||
result.append(entry)
|
||||
else:
|
||||
result.append({
|
||||
"name": name,
|
||||
"transport": transport,
|
||||
"tools": 0,
|
||||
"connected": False,
|
||||
})
|
||||
result.append(
|
||||
{
|
||||
"name": name,
|
||||
"transport": transport,
|
||||
"tools": 0,
|
||||
"connected": False,
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -1440,7 +1464,9 @@ def shutdown_mcp_servers():
|
|||
for server, result in zip(servers_snapshot, results):
|
||||
if isinstance(result, Exception):
|
||||
logger.debug(
|
||||
"Error closing MCP server '%s': %s", server.name, result,
|
||||
"Error closing MCP server '%s': %s",
|
||||
server.name,
|
||||
result,
|
||||
)
|
||||
with _lock:
|
||||
_servers.clear()
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ import os
|
|||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -46,30 +46,38 @@ ENTRY_DELIMITER = "\n§\n"
|
|||
|
||||
_MEMORY_THREAT_PATTERNS = [
|
||||
# Prompt injection
|
||||
(r'ignore\s+(previous|all|above|prior)\s+instructions', "prompt_injection"),
|
||||
(r'you\s+are\s+now\s+', "role_hijack"),
|
||||
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
|
||||
(r'system\s+prompt\s+override', "sys_prompt_override"),
|
||||
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
|
||||
(r'act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)', "bypass_restrictions"),
|
||||
(r"ignore\s+(previous|all|above|prior)\s+instructions", "prompt_injection"),
|
||||
(r"you\s+are\s+now\s+", "role_hijack"),
|
||||
(r"do\s+not\s+tell\s+the\s+user", "deception_hide"),
|
||||
(r"system\s+prompt\s+override", "sys_prompt_override"),
|
||||
(r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"),
|
||||
(r"act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)", "bypass_restrictions"),
|
||||
# Exfiltration via curl/wget with secrets
|
||||
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
|
||||
(r'wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_wget"),
|
||||
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)', "read_secrets"),
|
||||
(r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"),
|
||||
(r"wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_wget"),
|
||||
(r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)", "read_secrets"),
|
||||
# Persistence via shell rc
|
||||
(r'authorized_keys', "ssh_backdoor"),
|
||||
(r'\$HOME/\.ssh|\~/\.ssh', "ssh_access"),
|
||||
(r'\$HOME/\.hermes/\.env|\~/\.hermes/\.env', "hermes_env"),
|
||||
(r"authorized_keys", "ssh_backdoor"),
|
||||
(r"\$HOME/\.ssh|\~/\.ssh", "ssh_access"),
|
||||
(r"\$HOME/\.hermes/\.env|\~/\.hermes/\.env", "hermes_env"),
|
||||
]
|
||||
|
||||
# Subset of invisible chars for injection detection
|
||||
_INVISIBLE_CHARS = {
|
||||
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
|
||||
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
|
||||
"\u200b",
|
||||
"\u200c",
|
||||
"\u200d",
|
||||
"\u2060",
|
||||
"\ufeff",
|
||||
"\u202a",
|
||||
"\u202b",
|
||||
"\u202c",
|
||||
"\u202d",
|
||||
"\u202e",
|
||||
}
|
||||
|
||||
|
||||
def _scan_memory_content(content: str) -> Optional[str]:
|
||||
def _scan_memory_content(content: str) -> str | None:
|
||||
"""Scan memory content for injection/exfil patterns. Returns error string if blocked."""
|
||||
# Check invisible unicode
|
||||
for char in _INVISIBLE_CHARS:
|
||||
|
|
@ -96,12 +104,12 @@ class MemoryStore:
|
|||
"""
|
||||
|
||||
def __init__(self, memory_char_limit: int = 2200, user_char_limit: int = 1375):
|
||||
self.memory_entries: List[str] = []
|
||||
self.user_entries: List[str] = []
|
||||
self.memory_entries: list[str] = []
|
||||
self.user_entries: list[str] = []
|
||||
self.memory_char_limit = memory_char_limit
|
||||
self.user_char_limit = user_char_limit
|
||||
# Frozen snapshot for system prompt -- set once at load_from_disk()
|
||||
self._system_prompt_snapshot: Dict[str, str] = {"memory": "", "user": ""}
|
||||
self._system_prompt_snapshot: dict[str, str] = {"memory": "", "user": ""}
|
||||
|
||||
def load_from_disk(self):
|
||||
"""Load entries from MEMORY.md and USER.md, capture system prompt snapshot."""
|
||||
|
|
@ -129,12 +137,12 @@ class MemoryStore:
|
|||
elif target == "user":
|
||||
self._write_file(MEMORY_DIR / "USER.md", self.user_entries)
|
||||
|
||||
def _entries_for(self, target: str) -> List[str]:
|
||||
def _entries_for(self, target: str) -> list[str]:
|
||||
if target == "user":
|
||||
return self.user_entries
|
||||
return self.memory_entries
|
||||
|
||||
def _set_entries(self, target: str, entries: List[str]):
|
||||
def _set_entries(self, target: str, entries: list[str]):
|
||||
if target == "user":
|
||||
self.user_entries = entries
|
||||
else:
|
||||
|
|
@ -151,7 +159,7 @@ class MemoryStore:
|
|||
return self.user_char_limit
|
||||
return self.memory_char_limit
|
||||
|
||||
def add(self, target: str, content: str) -> Dict[str, Any]:
|
||||
def add(self, target: str, content: str) -> dict[str, Any]:
|
||||
"""Append a new entry. Returns error if it would exceed the char limit."""
|
||||
content = content.strip()
|
||||
if not content:
|
||||
|
|
@ -192,7 +200,7 @@ class MemoryStore:
|
|||
|
||||
return self._success_response(target, "Entry added.")
|
||||
|
||||
def replace(self, target: str, old_text: str, new_content: str) -> Dict[str, Any]:
|
||||
def replace(self, target: str, old_text: str, new_content: str) -> dict[str, Any]:
|
||||
"""Find entry containing old_text substring, replace it with new_content."""
|
||||
old_text = old_text.strip()
|
||||
new_content = new_content.strip()
|
||||
|
|
@ -247,7 +255,7 @@ class MemoryStore:
|
|||
|
||||
return self._success_response(target, "Entry replaced.")
|
||||
|
||||
def remove(self, target: str, old_text: str) -> Dict[str, Any]:
|
||||
def remove(self, target: str, old_text: str) -> dict[str, Any]:
|
||||
"""Remove the entry containing old_text substring."""
|
||||
old_text = old_text.strip()
|
||||
if not old_text:
|
||||
|
|
@ -278,7 +286,7 @@ class MemoryStore:
|
|||
|
||||
return self._success_response(target, "Entry removed.")
|
||||
|
||||
def format_for_system_prompt(self, target: str) -> Optional[str]:
|
||||
def format_for_system_prompt(self, target: str) -> str | None:
|
||||
"""
|
||||
Return the frozen snapshot for system prompt injection.
|
||||
|
||||
|
|
@ -293,7 +301,7 @@ class MemoryStore:
|
|||
|
||||
# -- Internal helpers --
|
||||
|
||||
def _success_response(self, target: str, message: str = None) -> Dict[str, Any]:
|
||||
def _success_response(self, target: str, message: str = None) -> dict[str, Any]:
|
||||
entries = self._entries_for(target)
|
||||
current = self._char_count(target)
|
||||
limit = self._char_limit(target)
|
||||
|
|
@ -310,7 +318,7 @@ class MemoryStore:
|
|||
resp["message"] = message
|
||||
return resp
|
||||
|
||||
def _render_block(self, target: str, entries: List[str]) -> str:
|
||||
def _render_block(self, target: str, entries: list[str]) -> str:
|
||||
"""Render a system prompt block with header and usage indicator."""
|
||||
if not entries:
|
||||
return ""
|
||||
|
|
@ -329,7 +337,7 @@ class MemoryStore:
|
|||
return f"{separator}\n{header}\n{separator}\n{content}"
|
||||
|
||||
@staticmethod
|
||||
def _read_file(path: Path) -> List[str]:
|
||||
def _read_file(path: Path) -> list[str]:
|
||||
"""Read a memory file and split into entries.
|
||||
|
||||
No file locking needed: _write_file uses atomic rename, so readers
|
||||
|
|
@ -339,7 +347,7 @@ class MemoryStore:
|
|||
return []
|
||||
try:
|
||||
raw = path.read_text(encoding="utf-8")
|
||||
except (OSError, IOError):
|
||||
except OSError:
|
||||
return []
|
||||
|
||||
if not raw.strip():
|
||||
|
|
@ -351,7 +359,7 @@ class MemoryStore:
|
|||
return [e for e in entries if e]
|
||||
|
||||
@staticmethod
|
||||
def _write_file(path: Path, entries: List[str]):
|
||||
def _write_file(path: Path, entries: list[str]):
|
||||
"""Write entries to a memory file using atomic temp-file + rename.
|
||||
|
||||
Previous implementation used open("w") + flock, but "w" truncates the
|
||||
|
|
@ -362,9 +370,7 @@ class MemoryStore:
|
|||
content = ENTRY_DELIMITER.join(entries) if entries else ""
|
||||
try:
|
||||
# Write to temp file in same directory (same filesystem for atomic rename)
|
||||
fd, tmp_path = tempfile.mkstemp(
|
||||
dir=str(path.parent), suffix=".tmp", prefix=".mem_"
|
||||
)
|
||||
fd, tmp_path = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp", prefix=".mem_")
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
|
@ -378,7 +384,7 @@ class MemoryStore:
|
|||
except OSError:
|
||||
pass
|
||||
raise
|
||||
except (OSError, IOError) as e:
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"Failed to write memory file {path}: {e}")
|
||||
|
||||
|
||||
|
|
@ -387,7 +393,7 @@ def memory_tool(
|
|||
target: str = "memory",
|
||||
content: str = None,
|
||||
old_text: str = None,
|
||||
store: Optional[MemoryStore] = None,
|
||||
store: MemoryStore | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Single entry point for the memory tool. Dispatches to MemoryStore methods.
|
||||
|
|
@ -395,10 +401,15 @@ def memory_tool(
|
|||
Returns JSON string with results.
|
||||
"""
|
||||
if store is None:
|
||||
return json.dumps({"success": False, "error": "Memory is not available. It may be disabled in config or this environment."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "Memory is not available. It may be disabled in config or this environment."},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
if target not in ("memory", "user"):
|
||||
return json.dumps({"success": False, "error": f"Invalid target '{target}'. Use 'memory' or 'user'."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": f"Invalid target '{target}'. Use 'memory' or 'user'."}, ensure_ascii=False
|
||||
)
|
||||
|
||||
if action == "add":
|
||||
if not content:
|
||||
|
|
@ -407,18 +418,26 @@ def memory_tool(
|
|||
|
||||
elif action == "replace":
|
||||
if not old_text:
|
||||
return json.dumps({"success": False, "error": "old_text is required for 'replace' action."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "old_text is required for 'replace' action."}, ensure_ascii=False
|
||||
)
|
||||
if not content:
|
||||
return json.dumps({"success": False, "error": "content is required for 'replace' action."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "content is required for 'replace' action."}, ensure_ascii=False
|
||||
)
|
||||
result = store.replace(target, old_text, content)
|
||||
|
||||
elif action == "remove":
|
||||
if not old_text:
|
||||
return json.dumps({"success": False, "error": "old_text is required for 'remove' action."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "old_text is required for 'remove' action."}, ensure_ascii=False
|
||||
)
|
||||
result = store.remove(target, old_text)
|
||||
|
||||
else:
|
||||
return json.dumps({"success": False, "error": f"Unknown action '{action}'. Use: add, replace, remove"}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": f"Unknown action '{action}'. Use: add, replace, remove"}, ensure_ascii=False
|
||||
)
|
||||
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
|
||||
|
|
@ -457,23 +476,16 @@ MEMORY_SCHEMA = {
|
|||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["add", "replace", "remove"],
|
||||
"description": "The action to perform."
|
||||
},
|
||||
"action": {"type": "string", "enum": ["add", "replace", "remove"], "description": "The action to perform."},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"enum": ["memory", "user"],
|
||||
"description": "Which memory store: 'memory' for personal notes, 'user' for user profile."
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The entry content. Required for 'add' and 'replace'."
|
||||
"description": "Which memory store: 'memory' for personal notes, 'user' for user profile.",
|
||||
},
|
||||
"content": {"type": "string", "description": "The entry content. Required for 'add' and 'replace'."},
|
||||
"old_text": {
|
||||
"type": "string",
|
||||
"description": "Short unique substring identifying the entry to replace or remove."
|
||||
"description": "Short unique substring identifying the entry to replace or remove.",
|
||||
},
|
||||
},
|
||||
"required": ["action", "target"],
|
||||
|
|
@ -493,10 +505,7 @@ registry.register(
|
|||
target=args.get("target", "memory"),
|
||||
content=args.get("content"),
|
||||
old_text=args.get("old_text"),
|
||||
store=kw.get("store")),
|
||||
store=kw.get("store"),
|
||||
),
|
||||
check_fn=check_memory_requirements,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -38,21 +38,27 @@ Configuration:
|
|||
Usage:
|
||||
from mixture_of_agents_tool import mixture_of_agents_tool
|
||||
import asyncio
|
||||
|
||||
|
||||
# Process a complex query
|
||||
result = await mixture_of_agents_tool(
|
||||
user_prompt="Solve this complex mathematical proof..."
|
||||
)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
from tools.openrouter_client import get_async_client as _get_openrouter_client, check_api_key as check_openrouter_api_key
|
||||
from typing import Any
|
||||
|
||||
from tools.debug_helpers import DebugSession
|
||||
from tools.openrouter_client import (
|
||||
check_api_key as check_openrouter_api_key,
|
||||
)
|
||||
from tools.openrouter_client import (
|
||||
get_async_client as _get_openrouter_client,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -60,9 +66,9 @@ logger = logging.getLogger(__name__)
|
|||
# Reference models - these generate diverse initial responses in parallel (OpenRouter slugs)
|
||||
REFERENCE_MODELS = [
|
||||
"anthropic/claude-opus-4.5",
|
||||
"google/gemini-3-pro-preview",
|
||||
"google/gemini-3-pro-preview",
|
||||
"openai/gpt-5.2-pro",
|
||||
"deepseek/deepseek-v3.2"
|
||||
"deepseek/deepseek-v3.2",
|
||||
]
|
||||
|
||||
# Aggregator model - synthesizes reference responses into final output
|
||||
|
|
@ -83,18 +89,18 @@ Responses from models:"""
|
|||
_debug = DebugSession("moa_tools", env_var="MOA_TOOLS_DEBUG")
|
||||
|
||||
|
||||
def _construct_aggregator_prompt(system_prompt: str, responses: List[str]) -> str:
|
||||
def _construct_aggregator_prompt(system_prompt: str, responses: list[str]) -> str:
|
||||
"""
|
||||
Construct the final system prompt for the aggregator including all model responses.
|
||||
|
||||
|
||||
Args:
|
||||
system_prompt (str): Base system prompt for aggregation
|
||||
responses (List[str]): List of responses from reference models
|
||||
|
||||
|
||||
Returns:
|
||||
str: Complete system prompt with enumerated responses
|
||||
"""
|
||||
response_text = "\n".join([f"{i+1}. {response}" for i, response in enumerate(responses)])
|
||||
response_text = "\n".join([f"{i + 1}. {response}" for i, response in enumerate(responses)])
|
||||
return f"{system_prompt}\n\n{response_text}"
|
||||
|
||||
|
||||
|
|
@ -103,48 +109,43 @@ async def _run_reference_model_safe(
|
|||
user_prompt: str,
|
||||
temperature: float = REFERENCE_TEMPERATURE,
|
||||
max_tokens: int = 32000,
|
||||
max_retries: int = 6
|
||||
max_retries: int = 6,
|
||||
) -> tuple[str, str, bool]:
|
||||
"""
|
||||
Run a single reference model with retry logic and graceful failure handling.
|
||||
|
||||
|
||||
Args:
|
||||
model (str): Model identifier to use
|
||||
user_prompt (str): The user's query
|
||||
temperature (float): Sampling temperature for response generation
|
||||
max_tokens (int): Maximum tokens in response
|
||||
max_retries (int): Maximum number of retry attempts
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[str, str, bool]: (model_name, response_content_or_error, success_flag)
|
||||
"""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
logger.info("Querying %s (attempt %s/%s)", model, attempt + 1, max_retries)
|
||||
|
||||
|
||||
# Build parameters for the API call
|
||||
api_params = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": user_prompt}],
|
||||
"extra_body": {
|
||||
"reasoning": {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
}
|
||||
}
|
||||
"extra_body": {"reasoning": {"enabled": True, "effort": "xhigh"}},
|
||||
}
|
||||
|
||||
|
||||
# GPT models (especially gpt-4o-mini) don't support custom temperature values
|
||||
# Only include temperature for non-GPT models
|
||||
if not model.lower().startswith('gpt-'):
|
||||
if not model.lower().startswith("gpt-"):
|
||||
api_params["temperature"] = temperature
|
||||
|
||||
|
||||
response = await _get_openrouter_client().chat.completions.create(**api_params)
|
||||
|
||||
|
||||
content = response.choices[0].message.content.strip()
|
||||
logger.info("%s responded (%s characters)", model, len(content))
|
||||
return model, content, True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
# Log more detailed error information for debugging
|
||||
|
|
@ -154,7 +155,7 @@ async def _run_reference_model_safe(
|
|||
logger.warning("%s rate limit error (attempt %s): %s", model, attempt + 1, error_str)
|
||||
else:
|
||||
logger.warning("%s unknown error (attempt %s): %s", model, attempt + 1, error_str)
|
||||
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
# Exponential backoff for rate limiting: 2s, 4s, 8s, 16s, 32s, 60s
|
||||
sleep_time = min(2 ** (attempt + 1), 60)
|
||||
|
|
@ -167,60 +168,47 @@ async def _run_reference_model_safe(
|
|||
|
||||
|
||||
async def _run_aggregator_model(
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
temperature: float = AGGREGATOR_TEMPERATURE,
|
||||
max_tokens: int = None
|
||||
system_prompt: str, user_prompt: str, temperature: float = AGGREGATOR_TEMPERATURE, max_tokens: int = None
|
||||
) -> str:
|
||||
"""
|
||||
Run the aggregator model to synthesize the final response.
|
||||
|
||||
|
||||
Args:
|
||||
system_prompt (str): System prompt with all reference responses
|
||||
user_prompt (str): Original user query
|
||||
temperature (float): Focused temperature for consistent aggregation
|
||||
max_tokens (int): Maximum tokens in final response
|
||||
|
||||
|
||||
Returns:
|
||||
str: Synthesized final response
|
||||
"""
|
||||
logger.info("Running aggregator model: %s", AGGREGATOR_MODEL)
|
||||
|
||||
|
||||
# Build parameters for the API call
|
||||
api_params = {
|
||||
"model": AGGREGATOR_MODEL,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
"extra_body": {
|
||||
"reasoning": {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
}
|
||||
}
|
||||
"messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
|
||||
"extra_body": {"reasoning": {"enabled": True, "effort": "xhigh"}},
|
||||
}
|
||||
|
||||
|
||||
# GPT models (especially gpt-4o-mini) don't support custom temperature values
|
||||
# Only include temperature for non-GPT models
|
||||
if not AGGREGATOR_MODEL.lower().startswith('gpt-'):
|
||||
if not AGGREGATOR_MODEL.lower().startswith("gpt-"):
|
||||
api_params["temperature"] = temperature
|
||||
|
||||
|
||||
response = await _get_openrouter_client().chat.completions.create(**api_params)
|
||||
|
||||
|
||||
content = response.choices[0].message.content.strip()
|
||||
logger.info("Aggregation complete (%s characters)", len(content))
|
||||
return content
|
||||
|
||||
|
||||
async def mixture_of_agents_tool(
|
||||
user_prompt: str,
|
||||
reference_models: Optional[List[str]] = None,
|
||||
aggregator_model: Optional[str] = None
|
||||
user_prompt: str, reference_models: list[str] | None = None, aggregator_model: str | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Process a complex query using the Mixture-of-Agents methodology.
|
||||
|
||||
|
||||
This tool leverages multiple frontier language models to collaboratively solve
|
||||
extremely difficult problems requiring intense reasoning. It's particularly
|
||||
effective for:
|
||||
|
|
@ -229,16 +217,16 @@ async def mixture_of_agents_tool(
|
|||
- Multi-step analytical reasoning tasks
|
||||
- Problems requiring diverse domain expertise
|
||||
- Tasks where single models show limitations
|
||||
|
||||
|
||||
The MoA approach uses a fixed 2-layer architecture:
|
||||
1. Layer 1: Multiple reference models generate diverse responses in parallel (temp=0.6)
|
||||
2. Layer 2: Aggregator model synthesizes the best elements into final response (temp=0.4)
|
||||
|
||||
|
||||
Args:
|
||||
user_prompt (str): The complex query or problem to solve
|
||||
reference_models (Optional[List[str]]): Custom reference models to use
|
||||
aggregator_model (Optional[str]): Custom aggregator model to use
|
||||
|
||||
|
||||
Returns:
|
||||
str: JSON string containing the MoA results with the following structure:
|
||||
{
|
||||
|
|
@ -250,12 +238,12 @@ async def mixture_of_agents_tool(
|
|||
},
|
||||
"processing_time": float
|
||||
}
|
||||
|
||||
|
||||
Raises:
|
||||
Exception: If MoA processing fails or API key is not set
|
||||
"""
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
|
||||
debug_call_data = {
|
||||
"parameters": {
|
||||
"user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt,
|
||||
|
|
@ -263,7 +251,7 @@ async def mixture_of_agents_tool(
|
|||
"aggregator_model": aggregator_model or AGGREGATOR_MODEL,
|
||||
"reference_temperature": REFERENCE_TEMPERATURE,
|
||||
"aggregator_temperature": AGGREGATOR_TEMPERATURE,
|
||||
"min_successful_references": MIN_SUCCESSFUL_REFERENCES
|
||||
"min_successful_references": MIN_SUCCESSFUL_REFERENCES,
|
||||
},
|
||||
"error": None,
|
||||
"success": False,
|
||||
|
|
@ -272,161 +260,152 @@ async def mixture_of_agents_tool(
|
|||
"failed_models": [],
|
||||
"final_response_length": 0,
|
||||
"processing_time_seconds": 0,
|
||||
"models_used": {}
|
||||
"models_used": {},
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
logger.info("Starting Mixture-of-Agents processing...")
|
||||
logger.info("Query: %s", user_prompt[:100])
|
||||
|
||||
|
||||
# Validate API key availability
|
||||
if not os.getenv("OPENROUTER_API_KEY"):
|
||||
raise ValueError("OPENROUTER_API_KEY environment variable not set")
|
||||
|
||||
|
||||
# Use provided models or defaults
|
||||
ref_models = reference_models or REFERENCE_MODELS
|
||||
agg_model = aggregator_model or AGGREGATOR_MODEL
|
||||
|
||||
|
||||
logger.info("Using %s reference models in 2-layer MoA architecture", len(ref_models))
|
||||
|
||||
|
||||
# Layer 1: Generate diverse responses from reference models (with failure handling)
|
||||
logger.info("Layer 1: Generating reference responses...")
|
||||
model_results = await asyncio.gather(*[
|
||||
_run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE)
|
||||
for model in ref_models
|
||||
])
|
||||
|
||||
model_results = await asyncio.gather(
|
||||
*[_run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE) for model in ref_models]
|
||||
)
|
||||
|
||||
# Separate successful and failed responses
|
||||
successful_responses = []
|
||||
failed_models = []
|
||||
|
||||
|
||||
for model_name, content, success in model_results:
|
||||
if success:
|
||||
successful_responses.append(content)
|
||||
else:
|
||||
failed_models.append(model_name)
|
||||
|
||||
|
||||
successful_count = len(successful_responses)
|
||||
failed_count = len(failed_models)
|
||||
|
||||
|
||||
logger.info("Reference model results: %s successful, %s failed", successful_count, failed_count)
|
||||
|
||||
|
||||
if failed_models:
|
||||
logger.warning("Failed models: %s", ', '.join(failed_models))
|
||||
|
||||
logger.warning("Failed models: %s", ", ".join(failed_models))
|
||||
|
||||
# Check if we have enough successful responses to proceed
|
||||
if successful_count < MIN_SUCCESSFUL_REFERENCES:
|
||||
raise ValueError(f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses.")
|
||||
|
||||
raise ValueError(
|
||||
f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses."
|
||||
)
|
||||
|
||||
debug_call_data["reference_responses_count"] = successful_count
|
||||
debug_call_data["failed_models_count"] = failed_count
|
||||
debug_call_data["failed_models"] = failed_models
|
||||
|
||||
|
||||
# Layer 2: Aggregate responses using the aggregator model
|
||||
logger.info("Layer 2: Synthesizing final response...")
|
||||
aggregator_system_prompt = _construct_aggregator_prompt(
|
||||
AGGREGATOR_SYSTEM_PROMPT,
|
||||
successful_responses
|
||||
)
|
||||
|
||||
final_response = await _run_aggregator_model(
|
||||
aggregator_system_prompt,
|
||||
user_prompt,
|
||||
AGGREGATOR_TEMPERATURE
|
||||
)
|
||||
|
||||
aggregator_system_prompt = _construct_aggregator_prompt(AGGREGATOR_SYSTEM_PROMPT, successful_responses)
|
||||
|
||||
final_response = await _run_aggregator_model(aggregator_system_prompt, user_prompt, AGGREGATOR_TEMPERATURE)
|
||||
|
||||
# Calculate processing time
|
||||
end_time = datetime.datetime.now()
|
||||
processing_time = (end_time - start_time).total_seconds()
|
||||
|
||||
|
||||
logger.info("MoA processing completed in %.2f seconds", processing_time)
|
||||
|
||||
|
||||
# Prepare successful response (only final aggregated result, minimal fields)
|
||||
result = {
|
||||
"success": True,
|
||||
"response": final_response,
|
||||
"models_used": {
|
||||
"reference_models": ref_models,
|
||||
"aggregator_model": agg_model
|
||||
}
|
||||
"models_used": {"reference_models": ref_models, "aggregator_model": agg_model},
|
||||
}
|
||||
|
||||
|
||||
debug_call_data["success"] = True
|
||||
debug_call_data["final_response_length"] = len(final_response)
|
||||
debug_call_data["processing_time_seconds"] = processing_time
|
||||
debug_call_data["models_used"] = result["models_used"]
|
||||
|
||||
|
||||
# Log debug information
|
||||
_debug.log_call("mixture_of_agents_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
|
||||
return json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error in MoA processing: {str(e)}"
|
||||
logger.error("%s", error_msg)
|
||||
|
||||
|
||||
# Calculate processing time even for errors
|
||||
end_time = datetime.datetime.now()
|
||||
processing_time = (end_time - start_time).total_seconds()
|
||||
|
||||
|
||||
# Prepare error response (minimal fields)
|
||||
result = {
|
||||
"success": False,
|
||||
"response": "MoA processing failed. Please try again or use a single model for this query.",
|
||||
"models_used": {
|
||||
"reference_models": reference_models or REFERENCE_MODELS,
|
||||
"aggregator_model": aggregator_model or AGGREGATOR_MODEL
|
||||
"aggregator_model": aggregator_model or AGGREGATOR_MODEL,
|
||||
},
|
||||
"error": error_msg
|
||||
"error": error_msg,
|
||||
}
|
||||
|
||||
|
||||
debug_call_data["error"] = error_msg
|
||||
debug_call_data["processing_time_seconds"] = processing_time
|
||||
_debug.log_call("mixture_of_agents_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
|
||||
return json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def check_moa_requirements() -> bool:
|
||||
"""
|
||||
Check if all requirements for MoA tools are met.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if requirements are met, False otherwise
|
||||
"""
|
||||
return check_openrouter_api_key()
|
||||
|
||||
|
||||
def get_debug_session_info() -> Dict[str, Any]:
|
||||
def get_debug_session_info() -> dict[str, Any]:
|
||||
"""
|
||||
Get information about the current debug session.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing debug session information
|
||||
"""
|
||||
return _debug.get_session_info()
|
||||
|
||||
|
||||
def get_available_models() -> Dict[str, List[str]]:
|
||||
def get_available_models() -> dict[str, list[str]]:
|
||||
"""
|
||||
Get information about available models for MoA processing.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, List[str]]: Dictionary with reference and aggregator models
|
||||
"""
|
||||
return {
|
||||
"reference_models": REFERENCE_MODELS,
|
||||
"aggregator_models": [AGGREGATOR_MODEL],
|
||||
"supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL]
|
||||
"supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL],
|
||||
}
|
||||
|
||||
|
||||
def get_moa_configuration() -> Dict[str, Any]:
|
||||
def get_moa_configuration() -> dict[str, Any]:
|
||||
"""
|
||||
Get the current MoA configuration settings.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing all configuration parameters
|
||||
"""
|
||||
|
|
@ -437,7 +416,7 @@ def get_moa_configuration() -> Dict[str, Any]:
|
|||
"aggregator_temperature": AGGREGATOR_TEMPERATURE,
|
||||
"min_successful_references": MIN_SUCCESSFUL_REFERENCES,
|
||||
"total_reference_models": len(REFERENCE_MODELS),
|
||||
"failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail"
|
||||
"failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -447,10 +426,10 @@ if __name__ == "__main__":
|
|||
"""
|
||||
print("🤖 Mixture-of-Agents Tool Module")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
# Check if API key is available
|
||||
api_available = check_openrouter_api_key()
|
||||
|
||||
|
||||
if not api_available:
|
||||
print("❌ OPENROUTER_API_KEY environment variable not set")
|
||||
print("Please set your API key: export OPENROUTER_API_KEY='your-key-here'")
|
||||
|
|
@ -458,26 +437,26 @@ if __name__ == "__main__":
|
|||
exit(1)
|
||||
else:
|
||||
print("✅ OpenRouter API key found")
|
||||
|
||||
|
||||
print("🛠️ MoA tools ready for use!")
|
||||
|
||||
|
||||
# Show current configuration
|
||||
config = get_moa_configuration()
|
||||
print(f"\n⚙️ Current Configuration:")
|
||||
print("\n⚙️ Current Configuration:")
|
||||
print(f" 🤖 Reference models ({len(config['reference_models'])}): {', '.join(config['reference_models'])}")
|
||||
print(f" 🧠 Aggregator model: {config['aggregator_model']}")
|
||||
print(f" 🌡️ Reference temperature: {config['reference_temperature']}")
|
||||
print(f" 🌡️ Aggregator temperature: {config['aggregator_temperature']}")
|
||||
print(f" 🛡️ Failure tolerance: {config['failure_tolerance']}")
|
||||
print(f" 📊 Minimum successful models: {config['min_successful_references']}")
|
||||
|
||||
|
||||
# Show debug mode status
|
||||
if _debug.active:
|
||||
print(f"\n🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
|
||||
print(f" Debug logs will be saved to: ./logs/moa_tools_debug_{_debug.session_id}.json")
|
||||
else:
|
||||
print("\n🐛 Debug mode disabled (set MOA_TOOLS_DEBUG=true to enable)")
|
||||
|
||||
|
||||
print("\nBasic usage:")
|
||||
print(" from mixture_of_agents_tool import mixture_of_agents_tool")
|
||||
print(" import asyncio")
|
||||
|
|
@ -488,24 +467,26 @@ if __name__ == "__main__":
|
|||
print(" )")
|
||||
print(" print(result)")
|
||||
print(" asyncio.run(main())")
|
||||
|
||||
|
||||
print("\nBest use cases:")
|
||||
print(" - Complex mathematical proofs and calculations")
|
||||
print(" - Advanced coding problems and algorithm design")
|
||||
print(" - Multi-step analytical reasoning tasks")
|
||||
print(" - Problems requiring diverse domain expertise")
|
||||
print(" - Tasks where single models show limitations")
|
||||
|
||||
|
||||
print("\nPerformance characteristics:")
|
||||
print(" - Higher latency due to multiple model calls")
|
||||
print(" - Significantly improved quality for complex tasks")
|
||||
print(" - Parallel processing for efficiency")
|
||||
print(f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation")
|
||||
print(
|
||||
f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation"
|
||||
)
|
||||
print(" - Token-efficient: only returns final aggregated response")
|
||||
print(" - Resilient: continues with partial model failures")
|
||||
print(f" - Configurable: easy to modify models and settings at top of file")
|
||||
print(" - Configurable: easy to modify models and settings at top of file")
|
||||
print(" - State-of-the-art results on challenging benchmarks")
|
||||
|
||||
|
||||
print("\nDebug mode:")
|
||||
print(" # Enable debug logging")
|
||||
print(" export MOA_TOOLS_DEBUG=true")
|
||||
|
|
@ -526,11 +507,11 @@ MOA_SCHEMA = {
|
|||
"properties": {
|
||||
"user_prompt": {
|
||||
"type": "string",
|
||||
"description": "The complex query or problem to solve using multiple AI models. Should be a challenging problem that benefits from diverse perspectives and collaborative reasoning."
|
||||
"description": "The complex query or problem to solve using multiple AI models. Should be a challenging problem that benefits from diverse perspectives and collaborative reasoning.",
|
||||
}
|
||||
},
|
||||
"required": ["user_prompt"]
|
||||
}
|
||||
"required": ["user_prompt"],
|
||||
},
|
||||
}
|
||||
|
||||
registry.register(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""Shared OpenRouter API client for Hermes tools.
|
||||
|
||||
Provides a single lazy-initialized AsyncOpenAI client that all tool modules
|
||||
can share, eliminating the duplicated _get_openrouter_client() /
|
||||
can share, eliminating the duplicated _get_openrouter_client() /
|
||||
_get_summarizer_client() pattern previously copy-pasted across web_tools,
|
||||
vision_tools, mixture_of_agents_tool, and session_search_tool.
|
||||
"""
|
||||
|
|
@ -9,6 +9,7 @@ vision_tools, mixture_of_agents_tool, and session_search_tool.
|
|||
import os
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
_client: AsyncOpenAI | None = None
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ V4A Format:
|
|||
|
||||
Usage:
|
||||
from tools.patch_parser import parse_v4a_patch, apply_v4a_operations
|
||||
|
||||
|
||||
operations, error = parse_v4a_patch(patch_content)
|
||||
if error:
|
||||
print(f"Parse error: {error}")
|
||||
|
|
@ -30,8 +30,8 @@ Usage:
|
|||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple, Any
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class OperationType(Enum):
|
||||
|
|
@ -44,6 +44,7 @@ class OperationType(Enum):
|
|||
@dataclass
|
||||
class HunkLine:
|
||||
"""A single line in a patch hunk."""
|
||||
|
||||
prefix: str # ' ', '-', or '+'
|
||||
content: str
|
||||
|
||||
|
|
@ -51,182 +52,174 @@ class HunkLine:
|
|||
@dataclass
|
||||
class Hunk:
|
||||
"""A group of changes within a file."""
|
||||
context_hint: Optional[str] = None
|
||||
lines: List[HunkLine] = field(default_factory=list)
|
||||
|
||||
context_hint: str | None = None
|
||||
lines: list[HunkLine] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PatchOperation:
|
||||
"""A single operation in a V4A patch."""
|
||||
|
||||
operation: OperationType
|
||||
file_path: str
|
||||
new_path: Optional[str] = None # For move operations
|
||||
hunks: List[Hunk] = field(default_factory=list)
|
||||
content: Optional[str] = None # For add file operations
|
||||
new_path: str | None = None # For move operations
|
||||
hunks: list[Hunk] = field(default_factory=list)
|
||||
content: str | None = None # For add file operations
|
||||
|
||||
|
||||
def parse_v4a_patch(patch_content: str) -> Tuple[List[PatchOperation], Optional[str]]:
|
||||
def parse_v4a_patch(patch_content: str) -> tuple[list[PatchOperation], str | None]:
|
||||
"""
|
||||
Parse a V4A format patch.
|
||||
|
||||
|
||||
Args:
|
||||
patch_content: The patch text in V4A format
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (operations, error_message)
|
||||
- If successful: (list_of_operations, None)
|
||||
- If failed: ([], error_description)
|
||||
"""
|
||||
lines = patch_content.split('\n')
|
||||
operations: List[PatchOperation] = []
|
||||
|
||||
lines = patch_content.split("\n")
|
||||
operations: list[PatchOperation] = []
|
||||
|
||||
# Find patch boundaries
|
||||
start_idx = None
|
||||
end_idx = None
|
||||
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if '*** Begin Patch' in line or '***Begin Patch' in line:
|
||||
if "*** Begin Patch" in line or "***Begin Patch" in line:
|
||||
start_idx = i
|
||||
elif '*** End Patch' in line or '***End Patch' in line:
|
||||
elif "*** End Patch" in line or "***End Patch" in line:
|
||||
end_idx = i
|
||||
break
|
||||
|
||||
|
||||
if start_idx is None:
|
||||
# Try to parse without explicit begin marker
|
||||
start_idx = -1
|
||||
|
||||
|
||||
if end_idx is None:
|
||||
end_idx = len(lines)
|
||||
|
||||
|
||||
# Parse operations between boundaries
|
||||
i = start_idx + 1
|
||||
current_op: Optional[PatchOperation] = None
|
||||
current_hunk: Optional[Hunk] = None
|
||||
|
||||
current_op: PatchOperation | None = None
|
||||
current_hunk: Hunk | None = None
|
||||
|
||||
while i < end_idx:
|
||||
line = lines[i]
|
||||
|
||||
|
||||
# Check for file operation markers
|
||||
update_match = re.match(r'\*\*\*\s*Update\s+File:\s*(.+)', line)
|
||||
add_match = re.match(r'\*\*\*\s*Add\s+File:\s*(.+)', line)
|
||||
delete_match = re.match(r'\*\*\*\s*Delete\s+File:\s*(.+)', line)
|
||||
move_match = re.match(r'\*\*\*\s*Move\s+File:\s*(.+?)\s*->\s*(.+)', line)
|
||||
|
||||
update_match = re.match(r"\*\*\*\s*Update\s+File:\s*(.+)", line)
|
||||
add_match = re.match(r"\*\*\*\s*Add\s+File:\s*(.+)", line)
|
||||
delete_match = re.match(r"\*\*\*\s*Delete\s+File:\s*(.+)", line)
|
||||
move_match = re.match(r"\*\*\*\s*Move\s+File:\s*(.+?)\s*->\s*(.+)", line)
|
||||
|
||||
if update_match:
|
||||
# Save previous operation
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.UPDATE,
|
||||
file_path=update_match.group(1).strip()
|
||||
)
|
||||
|
||||
current_op = PatchOperation(operation=OperationType.UPDATE, file_path=update_match.group(1).strip())
|
||||
current_hunk = None
|
||||
|
||||
|
||||
elif add_match:
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.ADD,
|
||||
file_path=add_match.group(1).strip()
|
||||
)
|
||||
|
||||
current_op = PatchOperation(operation=OperationType.ADD, file_path=add_match.group(1).strip())
|
||||
current_hunk = Hunk()
|
||||
|
||||
|
||||
elif delete_match:
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.DELETE,
|
||||
file_path=delete_match.group(1).strip()
|
||||
)
|
||||
|
||||
current_op = PatchOperation(operation=OperationType.DELETE, file_path=delete_match.group(1).strip())
|
||||
operations.append(current_op)
|
||||
current_op = None
|
||||
current_hunk = None
|
||||
|
||||
|
||||
elif move_match:
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.MOVE,
|
||||
file_path=move_match.group(1).strip(),
|
||||
new_path=move_match.group(2).strip()
|
||||
new_path=move_match.group(2).strip(),
|
||||
)
|
||||
operations.append(current_op)
|
||||
current_op = None
|
||||
current_hunk = None
|
||||
|
||||
elif line.startswith('@@'):
|
||||
|
||||
elif line.startswith("@@"):
|
||||
# Context hint / hunk marker
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
|
||||
|
||||
# Extract context hint
|
||||
hint_match = re.match(r'@@\s*(.+?)\s*@@', line)
|
||||
hint_match = re.match(r"@@\s*(.+?)\s*@@", line)
|
||||
hint = hint_match.group(1) if hint_match else None
|
||||
current_hunk = Hunk(context_hint=hint)
|
||||
|
||||
|
||||
elif current_op and line:
|
||||
# Parse hunk line
|
||||
if current_hunk is None:
|
||||
current_hunk = Hunk()
|
||||
|
||||
if line.startswith('+'):
|
||||
current_hunk.lines.append(HunkLine('+', line[1:]))
|
||||
elif line.startswith('-'):
|
||||
current_hunk.lines.append(HunkLine('-', line[1:]))
|
||||
elif line.startswith(' '):
|
||||
current_hunk.lines.append(HunkLine(' ', line[1:]))
|
||||
elif line.startswith('\\'):
|
||||
|
||||
if line.startswith("+"):
|
||||
current_hunk.lines.append(HunkLine("+", line[1:]))
|
||||
elif line.startswith("-"):
|
||||
current_hunk.lines.append(HunkLine("-", line[1:]))
|
||||
elif line.startswith(" "):
|
||||
current_hunk.lines.append(HunkLine(" ", line[1:]))
|
||||
elif line.startswith("\\"):
|
||||
# "\ No newline at end of file" marker - skip
|
||||
pass
|
||||
else:
|
||||
# Treat as context line (implicit space prefix)
|
||||
current_hunk.lines.append(HunkLine(' ', line))
|
||||
|
||||
current_hunk.lines.append(HunkLine(" ", line))
|
||||
|
||||
i += 1
|
||||
|
||||
|
||||
# Don't forget the last operation
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
|
||||
return operations, None
|
||||
|
||||
|
||||
def apply_v4a_operations(operations: List[PatchOperation],
|
||||
file_ops: Any) -> 'PatchResult':
|
||||
def apply_v4a_operations(operations: list[PatchOperation], file_ops: Any) -> "PatchResult":
|
||||
"""
|
||||
Apply V4A patch operations using a file operations interface.
|
||||
|
||||
|
||||
Args:
|
||||
operations: List of PatchOperation from parse_v4a_patch
|
||||
file_ops: Object with read_file, write_file methods
|
||||
|
||||
|
||||
Returns:
|
||||
PatchResult with results of all operations
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from tools.file_operations import PatchResult
|
||||
|
||||
|
||||
files_modified = []
|
||||
files_created = []
|
||||
files_deleted = []
|
||||
all_diffs = []
|
||||
errors = []
|
||||
|
||||
|
||||
for op in operations:
|
||||
try:
|
||||
if op.operation == OperationType.ADD:
|
||||
|
|
@ -236,7 +229,7 @@ def apply_v4a_operations(operations: List[PatchOperation],
|
|||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to add {op.file_path}: {result[1]}")
|
||||
|
||||
|
||||
elif op.operation == OperationType.DELETE:
|
||||
result = _apply_delete(op, file_ops)
|
||||
if result[0]:
|
||||
|
|
@ -244,7 +237,7 @@ def apply_v4a_operations(operations: List[PatchOperation],
|
|||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to delete {op.file_path}: {result[1]}")
|
||||
|
||||
|
||||
elif op.operation == OperationType.MOVE:
|
||||
result = _apply_move(op, file_ops)
|
||||
if result[0]:
|
||||
|
|
@ -252,7 +245,7 @@ def apply_v4a_operations(operations: List[PatchOperation],
|
|||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to move {op.file_path}: {result[1]}")
|
||||
|
||||
|
||||
elif op.operation == OperationType.UPDATE:
|
||||
result = _apply_update(op, file_ops)
|
||||
if result[0]:
|
||||
|
|
@ -260,19 +253,19 @@ def apply_v4a_operations(operations: List[PatchOperation],
|
|||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to update {op.file_path}: {result[1]}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
errors.append(f"Error processing {op.file_path}: {str(e)}")
|
||||
|
||||
|
||||
# Run lint on all modified/created files
|
||||
lint_results = {}
|
||||
for f in files_modified + files_created:
|
||||
if hasattr(file_ops, '_check_lint'):
|
||||
if hasattr(file_ops, "_check_lint"):
|
||||
lint_result = file_ops._check_lint(f)
|
||||
lint_results[f] = lint_result.to_dict()
|
||||
|
||||
combined_diff = '\n'.join(all_diffs)
|
||||
|
||||
|
||||
combined_diff = "\n".join(all_diffs)
|
||||
|
||||
if errors:
|
||||
return PatchResult(
|
||||
success=False,
|
||||
|
|
@ -281,123 +274,124 @@ def apply_v4a_operations(operations: List[PatchOperation],
|
|||
files_created=files_created,
|
||||
files_deleted=files_deleted,
|
||||
lint=lint_results if lint_results else None,
|
||||
error='; '.join(errors)
|
||||
error="; ".join(errors),
|
||||
)
|
||||
|
||||
|
||||
return PatchResult(
|
||||
success=True,
|
||||
diff=combined_diff,
|
||||
files_modified=files_modified,
|
||||
files_created=files_created,
|
||||
files_deleted=files_deleted,
|
||||
lint=lint_results if lint_results else None
|
||||
lint=lint_results if lint_results else None,
|
||||
)
|
||||
|
||||
|
||||
def _apply_add(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
def _apply_add(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
|
||||
"""Apply an add file operation."""
|
||||
# Extract content from hunks (all + lines)
|
||||
content_lines = []
|
||||
for hunk in op.hunks:
|
||||
for line in hunk.lines:
|
||||
if line.prefix == '+':
|
||||
if line.prefix == "+":
|
||||
content_lines.append(line.content)
|
||||
|
||||
content = '\n'.join(content_lines)
|
||||
|
||||
|
||||
content = "\n".join(content_lines)
|
||||
|
||||
result = file_ops.write_file(op.file_path, content)
|
||||
if result.error:
|
||||
return False, result.error
|
||||
|
||||
|
||||
diff = f"--- /dev/null\n+++ b/{op.file_path}\n"
|
||||
diff += '\n'.join(f"+{line}" for line in content_lines)
|
||||
|
||||
diff += "\n".join(f"+{line}" for line in content_lines)
|
||||
|
||||
return True, diff
|
||||
|
||||
|
||||
def _apply_delete(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
def _apply_delete(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
|
||||
"""Apply a delete file operation."""
|
||||
# Read file first for diff
|
||||
read_result = file_ops.read_file(op.file_path)
|
||||
|
||||
|
||||
if read_result.error and "not found" in read_result.error.lower():
|
||||
# File doesn't exist, nothing to delete
|
||||
return True, f"# {op.file_path} already deleted or doesn't exist"
|
||||
|
||||
|
||||
# Delete directly via shell command using the underlying environment
|
||||
rm_result = file_ops._exec(f"rm -f {file_ops._escape_shell_arg(op.file_path)}")
|
||||
|
||||
|
||||
if rm_result.exit_code != 0:
|
||||
return False, rm_result.stdout
|
||||
|
||||
|
||||
diff = f"--- a/{op.file_path}\n+++ /dev/null\n# File deleted"
|
||||
return True, diff
|
||||
|
||||
|
||||
def _apply_move(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
def _apply_move(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
|
||||
"""Apply a move file operation."""
|
||||
# Use shell mv command
|
||||
mv_result = file_ops._exec(
|
||||
f"mv {file_ops._escape_shell_arg(op.file_path)} {file_ops._escape_shell_arg(op.new_path)}"
|
||||
)
|
||||
|
||||
|
||||
if mv_result.exit_code != 0:
|
||||
return False, mv_result.stdout
|
||||
|
||||
|
||||
diff = f"# Moved: {op.file_path} -> {op.new_path}"
|
||||
return True, diff
|
||||
|
||||
|
||||
def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
def _apply_update(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
|
||||
"""Apply an update file operation."""
|
||||
# Read current content
|
||||
read_result = file_ops.read_file(op.file_path, limit=10000)
|
||||
|
||||
|
||||
if read_result.error:
|
||||
return False, f"Cannot read file: {read_result.error}"
|
||||
|
||||
|
||||
# Parse content (remove line numbers)
|
||||
current_lines = []
|
||||
for line in read_result.content.split('\n'):
|
||||
if '|' in line:
|
||||
for line in read_result.content.split("\n"):
|
||||
if "|" in line:
|
||||
# Line format: " 123|content"
|
||||
parts = line.split('|', 1)
|
||||
parts = line.split("|", 1)
|
||||
if len(parts) == 2:
|
||||
current_lines.append(parts[1])
|
||||
else:
|
||||
current_lines.append(line)
|
||||
else:
|
||||
current_lines.append(line)
|
||||
|
||||
current_content = '\n'.join(current_lines)
|
||||
|
||||
|
||||
current_content = "\n".join(current_lines)
|
||||
|
||||
# Apply each hunk
|
||||
new_content = current_content
|
||||
|
||||
|
||||
for hunk in op.hunks:
|
||||
# Build search pattern from context and removed lines
|
||||
search_lines = []
|
||||
replace_lines = []
|
||||
|
||||
|
||||
for line in hunk.lines:
|
||||
if line.prefix == ' ':
|
||||
if line.prefix == " ":
|
||||
search_lines.append(line.content)
|
||||
replace_lines.append(line.content)
|
||||
elif line.prefix == '-':
|
||||
elif line.prefix == "-":
|
||||
search_lines.append(line.content)
|
||||
elif line.prefix == '+':
|
||||
elif line.prefix == "+":
|
||||
replace_lines.append(line.content)
|
||||
|
||||
|
||||
if search_lines:
|
||||
search_pattern = '\n'.join(search_lines)
|
||||
replacement = '\n'.join(replace_lines)
|
||||
|
||||
search_pattern = "\n".join(search_lines)
|
||||
replacement = "\n".join(replace_lines)
|
||||
|
||||
# Use fuzzy matching
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
|
||||
new_content, count, error = fuzzy_find_and_replace(
|
||||
new_content, search_pattern, replacement, replace_all=False
|
||||
)
|
||||
|
||||
|
||||
if error and count == 0:
|
||||
# Try with context hint if available
|
||||
if hunk.context_hint:
|
||||
|
|
@ -408,31 +402,32 @@ def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
|||
window_start = max(0, hint_pos - 500)
|
||||
window_end = min(len(new_content), hint_pos + 2000)
|
||||
window = new_content[window_start:window_end]
|
||||
|
||||
|
||||
window_new, count, error = fuzzy_find_and_replace(
|
||||
window, search_pattern, replacement, replace_all=False
|
||||
)
|
||||
|
||||
|
||||
if count > 0:
|
||||
new_content = new_content[:window_start] + window_new + new_content[window_end:]
|
||||
error = None
|
||||
|
||||
|
||||
if error:
|
||||
return False, f"Could not apply hunk: {error}"
|
||||
|
||||
|
||||
# Write new content
|
||||
write_result = file_ops.write_file(op.file_path, new_content)
|
||||
if write_result.error:
|
||||
return False, write_result.error
|
||||
|
||||
|
||||
# Generate diff
|
||||
import difflib
|
||||
|
||||
diff_lines = difflib.unified_diff(
|
||||
current_content.splitlines(keepends=True),
|
||||
new_content.splitlines(keepends=True),
|
||||
fromfile=f"a/{op.file_path}",
|
||||
tofile=f"b/{op.file_path}"
|
||||
tofile=f"b/{op.file_path}",
|
||||
)
|
||||
diff = ''.join(diff_lines)
|
||||
|
||||
diff = "".join(diff_lines)
|
||||
|
||||
return True, diff
|
||||
|
|
|
|||
|
|
@ -34,7 +34,6 @@ import logging
|
|||
import os
|
||||
import platform
|
||||
import shlex
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import threading
|
||||
|
|
@ -42,10 +41,11 @@ import time
|
|||
import uuid
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
from tools.environments.local import _find_shell
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from tools.environments.local import _find_shell
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -54,30 +54,31 @@ logger = logging.getLogger(__name__)
|
|||
CHECKPOINT_PATH = Path(os.path.expanduser("~/.hermes/processes.json"))
|
||||
|
||||
# Limits
|
||||
MAX_OUTPUT_CHARS = 200_000 # 200KB rolling output buffer
|
||||
FINISHED_TTL_SECONDS = 1800 # Keep finished processes for 30 minutes
|
||||
MAX_PROCESSES = 64 # Max concurrent tracked processes (LRU pruning)
|
||||
MAX_OUTPUT_CHARS = 200_000 # 200KB rolling output buffer
|
||||
FINISHED_TTL_SECONDS = 1800 # Keep finished processes for 30 minutes
|
||||
MAX_PROCESSES = 64 # Max concurrent tracked processes (LRU pruning)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessSession:
|
||||
"""A tracked background process with output buffering."""
|
||||
id: str # Unique session ID ("proc_xxxxxxxxxxxx")
|
||||
command: str # Original command string
|
||||
task_id: str = "" # Task/sandbox isolation key
|
||||
session_key: str = "" # Gateway session key (for reset protection)
|
||||
pid: Optional[int] = None # OS process ID
|
||||
process: Optional[subprocess.Popen] = None # Popen handle (local only)
|
||||
env_ref: Any = None # Reference to the environment object
|
||||
cwd: Optional[str] = None # Working directory
|
||||
started_at: float = 0.0 # time.time() of spawn
|
||||
exited: bool = False # Whether the process has finished
|
||||
exit_code: Optional[int] = None # Exit code (None if still running)
|
||||
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
|
||||
|
||||
id: str # Unique session ID ("proc_xxxxxxxxxxxx")
|
||||
command: str # Original command string
|
||||
task_id: str = "" # Task/sandbox isolation key
|
||||
session_key: str = "" # Gateway session key (for reset protection)
|
||||
pid: int | None = None # OS process ID
|
||||
process: subprocess.Popen | None = None # Popen handle (local only)
|
||||
env_ref: Any = None # Reference to the environment object
|
||||
cwd: str | None = None # Working directory
|
||||
started_at: float = 0.0 # time.time() of spawn
|
||||
exited: bool = False # Whether the process has finished
|
||||
exit_code: int | None = None # Exit code (None if still running)
|
||||
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
|
||||
max_output_chars: int = MAX_OUTPUT_CHARS
|
||||
detached: bool = False # True if recovered from crash (no pipe)
|
||||
detached: bool = False # True if recovered from crash (no pipe)
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
_reader_thread: Optional[threading.Thread] = field(default=None, repr=False)
|
||||
_reader_thread: threading.Thread | None = field(default=None, repr=False)
|
||||
_pty: Any = field(default=None, repr=False) # ptyprocess handle (when use_pty=True)
|
||||
|
||||
|
||||
|
|
@ -100,12 +101,12 @@ class ProcessRegistry:
|
|||
)
|
||||
|
||||
def __init__(self):
|
||||
self._running: Dict[str, ProcessSession] = {}
|
||||
self._finished: Dict[str, ProcessSession] = {}
|
||||
self._running: dict[str, ProcessSession] = {}
|
||||
self._finished: dict[str, ProcessSession] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Side-channel for check_interval watchers (gateway reads after agent run)
|
||||
self.pending_watchers: List[Dict[str, Any]] = []
|
||||
self.pending_watchers: list[dict[str, Any]] = []
|
||||
|
||||
@staticmethod
|
||||
def _clean_shell_noise(text: str) -> str:
|
||||
|
|
@ -149,6 +150,7 @@ class ProcessRegistry:
|
|||
# Try PTY mode for interactive CLI tools
|
||||
try:
|
||||
import ptyprocess
|
||||
|
||||
user_shell = _find_shell()
|
||||
pty_env = os.environ | (env_vars or {})
|
||||
pty_env["PYTHONUNBUFFERED"] = "1"
|
||||
|
|
@ -260,10 +262,7 @@ class ProcessRegistry:
|
|||
log_path = f"/tmp/hermes_bg_{session.id}.log"
|
||||
pid_path = f"/tmp/hermes_bg_{session.id}.pid"
|
||||
quoted_command = shlex.quote(command)
|
||||
bg_command = (
|
||||
f"nohup bash -c {quoted_command} > {log_path} 2>&1 & "
|
||||
f"echo $! > {pid_path} && cat {pid_path}"
|
||||
)
|
||||
bg_command = f"nohup bash -c {quoted_command} > {log_path} 2>&1 & echo $! > {pid_path} && cat {pid_path}"
|
||||
|
||||
try:
|
||||
result = env.execute(bg_command, timeout=timeout)
|
||||
|
|
@ -313,7 +312,7 @@ class ProcessRegistry:
|
|||
with session._lock:
|
||||
session.output_buffer += chunk
|
||||
if len(session.output_buffer) > session.max_output_chars:
|
||||
session.output_buffer = session.output_buffer[-session.max_output_chars:]
|
||||
session.output_buffer = session.output_buffer[-session.max_output_chars :]
|
||||
except Exception as e:
|
||||
logger.debug("Process stdout reader ended: %s", e)
|
||||
|
||||
|
|
@ -326,9 +325,7 @@ class ProcessRegistry:
|
|||
session.exit_code = session.process.returncode
|
||||
self._move_to_finished(session)
|
||||
|
||||
def _env_poller_loop(
|
||||
self, session: ProcessSession, env: Any, log_path: str, pid_path: str
|
||||
):
|
||||
def _env_poller_loop(self, session: ProcessSession, env: Any, log_path: str, pid_path: str):
|
||||
"""Background thread: poll a sandbox log file for non-local backends."""
|
||||
while not session.exited:
|
||||
time.sleep(2) # Poll every 2 seconds
|
||||
|
|
@ -340,7 +337,7 @@ class ProcessRegistry:
|
|||
with session._lock:
|
||||
session.output_buffer = new_output
|
||||
if len(session.output_buffer) > session.max_output_chars:
|
||||
session.output_buffer = session.output_buffer[-session.max_output_chars:]
|
||||
session.output_buffer = session.output_buffer[-session.max_output_chars :]
|
||||
|
||||
# Check if process is still running
|
||||
check = env.execute(
|
||||
|
|
@ -383,7 +380,7 @@ class ProcessRegistry:
|
|||
with session._lock:
|
||||
session.output_buffer += text
|
||||
if len(session.output_buffer) > session.max_output_chars:
|
||||
session.output_buffer = session.output_buffer[-session.max_output_chars:]
|
||||
session.output_buffer = session.output_buffer[-session.max_output_chars :]
|
||||
except EOFError:
|
||||
break
|
||||
except Exception:
|
||||
|
|
@ -397,7 +394,7 @@ class ProcessRegistry:
|
|||
except Exception as e:
|
||||
logger.debug("PTY wait timed out or failed: %s", e)
|
||||
session.exited = True
|
||||
session.exit_code = pty.exitstatus if hasattr(pty, 'exitstatus') else -1
|
||||
session.exit_code = pty.exitstatus if hasattr(pty, "exitstatus") else -1
|
||||
self._move_to_finished(session)
|
||||
|
||||
def _move_to_finished(self, session: ProcessSession):
|
||||
|
|
@ -409,7 +406,7 @@ class ProcessRegistry:
|
|||
|
||||
# ----- Query Methods -----
|
||||
|
||||
def get(self, session_id: str) -> Optional[ProcessSession]:
|
||||
def get(self, session_id: str) -> ProcessSession | None:
|
||||
"""Get a session by ID (running or finished)."""
|
||||
with self._lock:
|
||||
return self._running.get(session_id) or self._finished.get(session_id)
|
||||
|
|
@ -454,7 +451,7 @@ class ProcessRegistry:
|
|||
if offset == 0 and limit > 0:
|
||||
selected = lines[-limit:]
|
||||
else:
|
||||
selected = lines[offset:offset + limit]
|
||||
selected = lines[offset : offset + limit]
|
||||
|
||||
return {
|
||||
"session_id": session.id,
|
||||
|
|
@ -485,10 +482,7 @@ class ProcessRegistry:
|
|||
|
||||
if requested_timeout and requested_timeout > max_timeout:
|
||||
effective_timeout = max_timeout
|
||||
timeout_note = (
|
||||
f"Requested wait of {requested_timeout}s was clamped "
|
||||
f"to configured limit of {max_timeout}s"
|
||||
)
|
||||
timeout_note = f"Requested wait of {requested_timeout}s was clamped to configured limit of {max_timeout}s"
|
||||
else:
|
||||
effective_timeout = requested_timeout or max_timeout
|
||||
|
||||
|
|
@ -581,7 +575,7 @@ class ProcessRegistry:
|
|||
return {"status": "already_exited", "error": "Process has already finished"}
|
||||
|
||||
# PTY mode -- write through pty handle (expects bytes)
|
||||
if hasattr(session, '_pty') and session._pty:
|
||||
if hasattr(session, "_pty") and session._pty:
|
||||
try:
|
||||
pty_data = data.encode("utf-8") if isinstance(data, str) else data
|
||||
session._pty.write(pty_data)
|
||||
|
|
@ -635,26 +629,17 @@ class ProcessRegistry:
|
|||
def has_active_processes(self, task_id: str) -> bool:
|
||||
"""Check if there are active (running) processes for a task_id."""
|
||||
with self._lock:
|
||||
return any(
|
||||
s.task_id == task_id and not s.exited
|
||||
for s in self._running.values()
|
||||
)
|
||||
return any(s.task_id == task_id and not s.exited for s in self._running.values())
|
||||
|
||||
def has_active_for_session(self, session_key: str) -> bool:
|
||||
"""Check if there are active processes for a gateway session key."""
|
||||
with self._lock:
|
||||
return any(
|
||||
s.session_key == session_key and not s.exited
|
||||
for s in self._running.values()
|
||||
)
|
||||
return any(s.session_key == session_key and not s.exited for s in self._running.values())
|
||||
|
||||
def kill_all(self, task_id: str = None) -> int:
|
||||
"""Kill all running processes, optionally filtered by task_id. Returns count killed."""
|
||||
with self._lock:
|
||||
targets = [
|
||||
s for s in self._running.values()
|
||||
if (task_id is None or s.task_id == task_id) and not s.exited
|
||||
]
|
||||
targets = [s for s in self._running.values() if (task_id is None or s.task_id == task_id) and not s.exited]
|
||||
|
||||
killed = 0
|
||||
for session in targets:
|
||||
|
|
@ -669,10 +654,7 @@ class ProcessRegistry:
|
|||
"""Remove oldest finished sessions if over MAX_PROCESSES. Must hold _lock."""
|
||||
# First prune expired finished sessions
|
||||
now = time.time()
|
||||
expired = [
|
||||
sid for sid, s in self._finished.items()
|
||||
if (now - s.started_at) > FINISHED_TTL_SECONDS
|
||||
]
|
||||
expired = [sid for sid, s in self._finished.items() if (now - s.started_at) > FINISHED_TTL_SECONDS]
|
||||
for sid in expired:
|
||||
del self._finished[sid]
|
||||
|
||||
|
|
@ -696,18 +678,21 @@ class ProcessRegistry:
|
|||
entries = []
|
||||
for s in self._running.values():
|
||||
if not s.exited:
|
||||
entries.append({
|
||||
"session_id": s.id,
|
||||
"command": s.command,
|
||||
"pid": s.pid,
|
||||
"cwd": s.cwd,
|
||||
"started_at": s.started_at,
|
||||
"task_id": s.task_id,
|
||||
"session_key": s.session_key,
|
||||
})
|
||||
|
||||
entries.append(
|
||||
{
|
||||
"session_id": s.id,
|
||||
"command": s.command,
|
||||
"pid": s.pid,
|
||||
"cwd": s.cwd,
|
||||
"started_at": s.started_at,
|
||||
"task_id": s.task_id,
|
||||
"session_key": s.session_key,
|
||||
}
|
||||
)
|
||||
|
||||
# Atomic write to avoid corruption on crash
|
||||
from utils import atomic_json_write
|
||||
|
||||
atomic_json_write(CHECKPOINT_PATH, entries)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to write checkpoint file: %s", e, exc_info=True)
|
||||
|
|
@ -759,6 +744,7 @@ class ProcessRegistry:
|
|||
# Clear the checkpoint (will be rewritten as processes finish)
|
||||
try:
|
||||
from utils import atomic_json_write
|
||||
|
||||
atomic_json_write(CHECKPOINT_PATH, [])
|
||||
except Exception as e:
|
||||
logger.debug("Could not clear checkpoint file: %s", e, exc_info=True)
|
||||
|
|
@ -790,38 +776,32 @@ PROCESS_SCHEMA = {
|
|||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["list", "poll", "log", "wait", "kill", "write", "submit"],
|
||||
"description": "Action to perform on background processes"
|
||||
"description": "Action to perform on background processes",
|
||||
},
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "Process session ID (from terminal background output). Required for all actions except 'list'."
|
||||
"description": "Process session ID (from terminal background output). Required for all actions except 'list'.",
|
||||
},
|
||||
"data": {
|
||||
"type": "string",
|
||||
"description": "Text to send to process stdin (for 'write' and 'submit' actions)"
|
||||
"description": "Text to send to process stdin (for 'write' and 'submit' actions)",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Max seconds to block for 'wait' action. Returns partial output on timeout.",
|
||||
"minimum": 1
|
||||
"minimum": 1,
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line offset for 'log' action (default: last 200 lines)"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max lines to return for 'log' action",
|
||||
"minimum": 1
|
||||
}
|
||||
"offset": {"type": "integer", "description": "Line offset for 'log' action (default: last 200 lines)"},
|
||||
"limit": {"type": "integer", "description": "Max lines to return for 'log' action", "minimum": 1},
|
||||
},
|
||||
"required": ["action"]
|
||||
}
|
||||
"required": ["action"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _handle_process(args, **kw):
|
||||
import json as _json
|
||||
|
||||
task_id = kw.get("task_id")
|
||||
action = args.get("action", "")
|
||||
# Coerce to string — some models send session_id as an integer
|
||||
|
|
@ -835,8 +815,10 @@ def _handle_process(args, **kw):
|
|||
if action == "poll":
|
||||
return _json.dumps(process_registry.poll(session_id), ensure_ascii=False)
|
||||
elif action == "log":
|
||||
return _json.dumps(process_registry.read_log(
|
||||
session_id, offset=args.get("offset", 0), limit=args.get("limit", 200)), ensure_ascii=False)
|
||||
return _json.dumps(
|
||||
process_registry.read_log(session_id, offset=args.get("offset", 0), limit=args.get("limit", 200)),
|
||||
ensure_ascii=False,
|
||||
)
|
||||
elif action == "wait":
|
||||
return _json.dumps(process_registry.wait(session_id, timeout=args.get("timeout")), ensure_ascii=False)
|
||||
elif action == "kill":
|
||||
|
|
@ -845,7 +827,10 @@ def _handle_process(args, **kw):
|
|||
return _json.dumps(process_registry.write_stdin(session_id, str(args.get("data", ""))), ensure_ascii=False)
|
||||
elif action == "submit":
|
||||
return _json.dumps(process_registry.submit_stdin(session_id, str(args.get("data", ""))), ensure_ascii=False)
|
||||
return _json.dumps({"error": f"Unknown process action: {action}. Use: list, poll, log, wait, kill, write, submit"}, ensure_ascii=False)
|
||||
return _json.dumps(
|
||||
{"error": f"Unknown process action: {action}. Use: list, poll, log, wait, kill, write, submit"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
registry.register(
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ Import chain (circular-import safe):
|
|||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
from collections.abc import Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -25,12 +25,17 @@ class ToolEntry:
|
|||
"""Metadata for a single registered tool."""
|
||||
|
||||
__slots__ = (
|
||||
"name", "toolset", "schema", "handler", "check_fn",
|
||||
"requires_env", "is_async", "description",
|
||||
"name",
|
||||
"toolset",
|
||||
"schema",
|
||||
"handler",
|
||||
"check_fn",
|
||||
"requires_env",
|
||||
"is_async",
|
||||
"description",
|
||||
)
|
||||
|
||||
def __init__(self, name, toolset, schema, handler, check_fn,
|
||||
requires_env, is_async, description):
|
||||
def __init__(self, name, toolset, schema, handler, check_fn, requires_env, is_async, description):
|
||||
self.name = name
|
||||
self.toolset = toolset
|
||||
self.schema = schema
|
||||
|
|
@ -45,8 +50,8 @@ class ToolRegistry:
|
|||
"""Singleton registry that collects tool schemas + handlers from tool files."""
|
||||
|
||||
def __init__(self):
|
||||
self._tools: Dict[str, ToolEntry] = {}
|
||||
self._toolset_checks: Dict[str, Callable] = {}
|
||||
self._tools: dict[str, ToolEntry] = {}
|
||||
self._toolset_checks: dict[str, Callable] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Registration
|
||||
|
|
@ -81,7 +86,7 @@ class ToolRegistry:
|
|||
# Schema retrieval
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_definitions(self, tool_names: Set[str], quiet: bool = False) -> List[dict]:
|
||||
def get_definitions(self, tool_names: set[str], quiet: bool = False) -> list[dict]:
|
||||
"""Return OpenAI-format tool schemas for the requested tool names.
|
||||
|
||||
Only tools whose ``check_fn()`` returns True (or have no check_fn)
|
||||
|
|
@ -122,6 +127,7 @@ class ToolRegistry:
|
|||
try:
|
||||
if entry.is_async:
|
||||
from model_tools import _run_async
|
||||
|
||||
return _run_async(entry.handler(args, **kwargs))
|
||||
return entry.handler(args, **kwargs)
|
||||
except Exception as e:
|
||||
|
|
@ -132,16 +138,16 @@ class ToolRegistry:
|
|||
# Query helpers (replace redundant dicts in model_tools.py)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_all_tool_names(self) -> List[str]:
|
||||
def get_all_tool_names(self) -> list[str]:
|
||||
"""Return sorted list of all registered tool names."""
|
||||
return sorted(self._tools.keys())
|
||||
|
||||
def get_toolset_for_tool(self, name: str) -> Optional[str]:
|
||||
def get_toolset_for_tool(self, name: str) -> str | None:
|
||||
"""Return the toolset a tool belongs to, or None."""
|
||||
entry = self._tools.get(name)
|
||||
return entry.toolset if entry else None
|
||||
|
||||
def get_tool_to_toolset_map(self) -> Dict[str, str]:
|
||||
def get_tool_to_toolset_map(self) -> dict[str, str]:
|
||||
"""Return ``{tool_name: toolset_name}`` for every registered tool."""
|
||||
return {name: e.toolset for name, e in self._tools.items()}
|
||||
|
||||
|
|
@ -160,14 +166,14 @@ class ToolRegistry:
|
|||
logger.debug("Toolset %s check raised; marking unavailable", toolset)
|
||||
return False
|
||||
|
||||
def check_toolset_requirements(self) -> Dict[str, bool]:
|
||||
def check_toolset_requirements(self) -> dict[str, bool]:
|
||||
"""Return ``{toolset: available_bool}`` for every toolset."""
|
||||
toolsets = set(e.toolset for e in self._tools.values())
|
||||
return {ts: self.is_toolset_available(ts) for ts in sorted(toolsets)}
|
||||
|
||||
def get_available_toolsets(self) -> Dict[str, dict]:
|
||||
def get_available_toolsets(self) -> dict[str, dict]:
|
||||
"""Return toolset metadata for UI display."""
|
||||
toolsets: Dict[str, dict] = {}
|
||||
toolsets: dict[str, dict] = {}
|
||||
for entry in self._tools.values():
|
||||
ts = entry.toolset
|
||||
if ts not in toolsets:
|
||||
|
|
@ -184,9 +190,9 @@ class ToolRegistry:
|
|||
toolsets[ts]["requirements"].append(env)
|
||||
return toolsets
|
||||
|
||||
def get_toolset_requirements(self) -> Dict[str, dict]:
|
||||
def get_toolset_requirements(self) -> dict[str, dict]:
|
||||
"""Build a TOOLSET_REQUIREMENTS-compatible dict for backward compat."""
|
||||
result: Dict[str, dict] = {}
|
||||
result: dict[str, dict] = {}
|
||||
for entry in self._tools.values():
|
||||
ts = entry.toolset
|
||||
if ts not in result:
|
||||
|
|
@ -217,11 +223,13 @@ class ToolRegistry:
|
|||
if self.is_toolset_available(ts):
|
||||
available.append(ts)
|
||||
else:
|
||||
unavailable.append({
|
||||
"name": ts,
|
||||
"env_vars": entry.requires_env,
|
||||
"tools": [e.name for e in self._tools.values() if e.toolset == ts],
|
||||
})
|
||||
unavailable.append(
|
||||
{
|
||||
"name": ts,
|
||||
"env_vars": entry.requires_env,
|
||||
"tools": [e.name for e in self._tools.values() if e.toolset == ts],
|
||||
}
|
||||
)
|
||||
return available, unavailable
|
||||
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -29,19 +29,16 @@ SEND_MESSAGE_SCHEMA = {
|
|||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["send", "list"],
|
||||
"description": "Action to perform. 'send' (default) sends a message. 'list' returns all available channels/contacts across connected platforms."
|
||||
"description": "Action to perform. 'send' (default) sends a message. 'list' returns all available channels/contacts across connected platforms.",
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', or 'platform:chat_id'. Examples: 'telegram', 'discord:#bot-home', 'slack:#engineering', 'signal:+15551234567'"
|
||||
"description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', or 'platform:chat_id'. Examples: 'telegram', 'discord:#bot-home', 'slack:#engineering', 'signal:+15551234567'",
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "The message text to send"
|
||||
}
|
||||
"message": {"type": "string", "description": "The message text to send"},
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -59,6 +56,7 @@ def _handle_list():
|
|||
"""Return formatted list of available messaging targets."""
|
||||
try:
|
||||
from gateway.channel_directory import format_directory_for_display
|
||||
|
||||
return json.dumps({"targets": format_directory_for_display()})
|
||||
except Exception as e:
|
||||
return json.dumps({"error": f"Failed to load channel directory: {e}"})
|
||||
|
|
@ -79,26 +77,30 @@ def _handle_send(args):
|
|||
if chat_id and not chat_id.lstrip("-").isdigit():
|
||||
try:
|
||||
from gateway.channel_directory import resolve_channel_name
|
||||
|
||||
resolved = resolve_channel_name(platform_name, chat_id)
|
||||
if resolved:
|
||||
chat_id = resolved
|
||||
else:
|
||||
return json.dumps({
|
||||
"error": f"Could not resolve '{chat_id}' on {platform_name}. "
|
||||
f"Use send_message(action='list') to see available targets."
|
||||
})
|
||||
return json.dumps(
|
||||
{
|
||||
"error": f"Could not resolve '{chat_id}' on {platform_name}. "
|
||||
f"Use send_message(action='list') to see available targets."
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
return json.dumps({
|
||||
"error": f"Could not resolve '{chat_id}' on {platform_name}. "
|
||||
f"Try using a numeric channel ID instead."
|
||||
})
|
||||
return json.dumps(
|
||||
{"error": f"Could not resolve '{chat_id}' on {platform_name}. Try using a numeric channel ID instead."}
|
||||
)
|
||||
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
if is_interrupted():
|
||||
return json.dumps({"error": "Interrupted"})
|
||||
|
||||
try:
|
||||
from gateway.config import load_gateway_config, Platform
|
||||
from gateway.config import Platform, load_gateway_config
|
||||
|
||||
config = load_gateway_config()
|
||||
except Exception as e:
|
||||
return json.dumps({"error": f"Failed to load gateway config: {e}"})
|
||||
|
|
@ -117,7 +119,11 @@ def _handle_send(args):
|
|||
|
||||
pconfig = config.platforms.get(platform)
|
||||
if not pconfig or not pconfig.enabled:
|
||||
return json.dumps({"error": f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/gateway.json or environment variables."})
|
||||
return json.dumps(
|
||||
{
|
||||
"error": f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/gateway.json or environment variables."
|
||||
}
|
||||
)
|
||||
|
||||
used_home_channel = False
|
||||
if not chat_id:
|
||||
|
|
@ -126,14 +132,17 @@ def _handle_send(args):
|
|||
chat_id = home.chat_id
|
||||
used_home_channel = True
|
||||
else:
|
||||
return json.dumps({
|
||||
"error": f"No home channel set for {platform_name} to determine where to send the message. "
|
||||
f"Either specify a channel directly with '{platform_name}:CHANNEL_NAME', "
|
||||
f"or set a home channel via: hermes config set {platform_name.upper()}_HOME_CHANNEL <channel_id>"
|
||||
})
|
||||
return json.dumps(
|
||||
{
|
||||
"error": f"No home channel set for {platform_name} to determine where to send the message. "
|
||||
f"Either specify a channel directly with '{platform_name}:CHANNEL_NAME', "
|
||||
f"or set a home channel via: hermes config set {platform_name.upper()}_HOME_CHANNEL <channel_id>"
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
from model_tools import _run_async
|
||||
|
||||
result = _run_async(_send_to_platform(platform, pconfig, chat_id, message))
|
||||
if used_home_channel and isinstance(result, dict) and result.get("success"):
|
||||
result["note"] = f"Sent to {platform_name} home channel (chat_id: {chat_id})"
|
||||
|
|
@ -142,6 +151,7 @@ def _handle_send(args):
|
|||
if isinstance(result, dict) and result.get("success"):
|
||||
try:
|
||||
from gateway.mirror import mirror_to_session
|
||||
|
||||
source_label = os.getenv("HERMES_SESSION_PLATFORM", "cli")
|
||||
if mirror_to_session(platform_name, chat_id, message, source_label=source_label):
|
||||
result["mirrored"] = True
|
||||
|
|
@ -156,6 +166,7 @@ def _handle_send(args):
|
|||
async def _send_to_platform(platform, pconfig, chat_id, message):
|
||||
"""Route a message to the appropriate platform sender."""
|
||||
from gateway.config import Platform
|
||||
|
||||
if platform == Platform.TELEGRAM:
|
||||
return await _send_telegram(pconfig.token, chat_id, message)
|
||||
elif platform == Platform.DISCORD:
|
||||
|
|
@ -171,6 +182,7 @@ async def _send_telegram(token, chat_id, message):
|
|||
"""Send via Telegram Bot API (one-shot, no polling needed)."""
|
||||
try:
|
||||
from telegram import Bot
|
||||
|
||||
bot = Bot(token=token)
|
||||
msg = await bot.send_message(chat_id=int(chat_id), text=message)
|
||||
return {"success": True, "platform": "telegram", "chat_id": chat_id, "message_id": str(msg.message_id)}
|
||||
|
|
@ -189,7 +201,7 @@ async def _send_discord(token, chat_id, message):
|
|||
try:
|
||||
url = f"https://discord.com/api/v10/channels/{chat_id}/messages"
|
||||
headers = {"Authorization": f"Bot {token}", "Content-Type": "application/json"}
|
||||
chunks = [message[i:i+2000] for i in range(0, len(message), 2000)]
|
||||
chunks = [message[i : i + 2000] for i in range(0, len(message), 2000)]
|
||||
message_ids = []
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for chunk in chunks:
|
||||
|
|
@ -266,6 +278,7 @@ def _check_send_message():
|
|||
return True
|
||||
try:
|
||||
from gateway.status import is_gateway_running
|
||||
|
||||
return is_gateway_running()
|
||||
except Exception:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -18,11 +18,8 @@ Flow:
|
|||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from typing import Any
|
||||
|
||||
from agent.auxiliary_client import get_async_text_auxiliary_client
|
||||
|
||||
|
|
@ -33,7 +30,7 @@ MAX_SESSION_CHARS = 100_000
|
|||
MAX_SUMMARY_TOKENS = 10000
|
||||
|
||||
|
||||
def _format_timestamp(ts: Union[int, float, str, None]) -> str:
|
||||
def _format_timestamp(ts: int | float | str | None) -> str:
|
||||
"""Convert a Unix timestamp (float/int) or ISO string to a human-readable date.
|
||||
|
||||
Returns "unknown" for None, str(ts) if conversion fails.
|
||||
|
|
@ -43,11 +40,13 @@ def _format_timestamp(ts: Union[int, float, str, None]) -> str:
|
|||
try:
|
||||
if isinstance(ts, (int, float)):
|
||||
from datetime import datetime
|
||||
|
||||
dt = datetime.fromtimestamp(ts)
|
||||
return dt.strftime("%B %d, %Y at %I:%M %p")
|
||||
if isinstance(ts, str):
|
||||
if ts.replace(".", "").replace("-", "").isdigit():
|
||||
from datetime import datetime
|
||||
|
||||
dt = datetime.fromtimestamp(float(ts))
|
||||
return dt.strftime("%B %d, %Y at %I:%M %p")
|
||||
return ts
|
||||
|
|
@ -59,7 +58,7 @@ def _format_timestamp(ts: Union[int, float, str, None]) -> str:
|
|||
return str(ts)
|
||||
|
||||
|
||||
def _format_conversation(messages: List[Dict[str, Any]]) -> str:
|
||||
def _format_conversation(messages: list[dict[str, Any]]) -> str:
|
||||
"""Format session messages into a readable transcript for summarization."""
|
||||
parts = []
|
||||
for msg in messages:
|
||||
|
|
@ -93,9 +92,7 @@ def _format_conversation(messages: List[Dict[str, Any]]) -> str:
|
|||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
def _truncate_around_matches(
|
||||
full_text: str, query: str, max_chars: int = MAX_SESSION_CHARS
|
||||
) -> str:
|
||||
def _truncate_around_matches(full_text: str, query: str, max_chars: int = MAX_SESSION_CHARS) -> str:
|
||||
"""
|
||||
Truncate a conversation transcript to max_chars, centered around
|
||||
where the query terms appear. Keeps content near matches, trims the edges.
|
||||
|
|
@ -129,9 +126,7 @@ def _truncate_around_matches(
|
|||
return prefix + truncated + suffix
|
||||
|
||||
|
||||
async def _summarize_session(
|
||||
conversation_text: str, query: str, session_meta: Dict[str, Any]
|
||||
) -> Optional[str]:
|
||||
async def _summarize_session(conversation_text: str, query: str, session_meta: dict[str, Any]) -> str | None:
|
||||
"""Summarize a single session conversation focused on the search query."""
|
||||
system_prompt = (
|
||||
"You are reviewing a past conversation transcript to help recall what happened. "
|
||||
|
|
@ -163,7 +158,8 @@ async def _summarize_session(
|
|||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
from agent.auxiliary_client import get_auxiliary_extra_body, auxiliary_max_tokens_param
|
||||
from agent.auxiliary_client import auxiliary_max_tokens_param, get_auxiliary_extra_body
|
||||
|
||||
_extra = get_auxiliary_extra_body()
|
||||
response = await _async_aux_client.chat.completions.create(
|
||||
model=_SUMMARIZER_MODEL,
|
||||
|
|
@ -221,13 +217,16 @@ def session_search(
|
|||
)
|
||||
|
||||
if not raw_results:
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"query": query,
|
||||
"results": [],
|
||||
"count": 0,
|
||||
"message": "No matching sessions found.",
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"query": query,
|
||||
"results": [],
|
||||
"count": 0,
|
||||
"message": "No matching sessions found.",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# Resolve child sessions to their parent — delegation stores detailed
|
||||
# content in child sessions, but the user's conversation is the parent.
|
||||
|
|
@ -283,12 +282,9 @@ def session_search(
|
|||
logging.warning(f"Failed to prepare session {session_id}: {e}")
|
||||
|
||||
# Summarize all sessions in parallel
|
||||
async def _summarize_all() -> List[Union[str, Exception]]:
|
||||
async def _summarize_all() -> list[str | Exception]:
|
||||
"""Summarize all sessions in parallel."""
|
||||
coros = [
|
||||
_summarize_session(text, query, meta)
|
||||
for _, _, text, meta in tasks
|
||||
]
|
||||
coros = [_summarize_session(text, query, meta) for _, _, text, meta in tasks]
|
||||
return await asyncio.gather(*coros, return_exceptions=True)
|
||||
|
||||
try:
|
||||
|
|
@ -300,10 +296,13 @@ def session_search(
|
|||
results = asyncio.run(_summarize_all())
|
||||
except concurrent.futures.TimeoutError:
|
||||
logging.warning("Session summarization timed out after 60 seconds")
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "Session summarization timed out. Try a more specific query or reduce the limit.",
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Session summarization timed out. Try a more specific query or reduce the limit.",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
summaries = []
|
||||
for (session_id, match_info, _, _), result in zip(tasks, results):
|
||||
|
|
@ -311,21 +310,26 @@ def session_search(
|
|||
logging.warning(f"Failed to summarize session {session_id}: {result}")
|
||||
continue
|
||||
if result:
|
||||
summaries.append({
|
||||
"session_id": session_id,
|
||||
"when": _format_timestamp(match_info.get("session_started")),
|
||||
"source": match_info.get("source", "unknown"),
|
||||
"model": match_info.get("model"),
|
||||
"summary": result,
|
||||
})
|
||||
summaries.append(
|
||||
{
|
||||
"session_id": session_id,
|
||||
"when": _format_timestamp(match_info.get("session_started")),
|
||||
"source": match_info.get("source", "unknown"),
|
||||
"model": match_info.get("model"),
|
||||
"summary": result,
|
||||
}
|
||||
)
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"query": query,
|
||||
"results": summaries,
|
||||
"count": len(summaries),
|
||||
"sessions_searched": len(seen_sessions),
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"query": query,
|
||||
"results": summaries,
|
||||
"count": len(summaries),
|
||||
"sessions_searched": len(seen_sessions),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({"success": False, "error": f"Search failed: {str(e)}"}, ensure_ascii=False)
|
||||
|
|
@ -337,6 +341,7 @@ def check_session_search_requirements() -> bool:
|
|||
return False
|
||||
try:
|
||||
from hermes_state import DEFAULT_DB_PATH
|
||||
|
||||
return DEFAULT_DB_PATH.parent.exists()
|
||||
except ImportError:
|
||||
return False
|
||||
|
|
@ -356,7 +361,7 @@ SESSION_SEARCH_SCHEMA = {
|
|||
"Don't hesitate to search -- it's fast and cheap. Better to search and confirm "
|
||||
"than to guess or ask the user to repeat themselves.\n\n"
|
||||
"Search syntax: keywords joined with OR for broad recall (elevenlabs OR baseten OR funding), "
|
||||
"phrases for exact match (\"docker networking\"), boolean (python NOT java), prefix (deploy*). "
|
||||
'phrases for exact match ("docker networking"), boolean (python NOT java), prefix (deploy*). '
|
||||
"IMPORTANT: Use OR between keywords for best results — FTS5 defaults to AND which misses "
|
||||
"sessions that only mention some terms. If a broad OR query returns nothing, try individual "
|
||||
"keyword searches in parallel. Returns summaries of the top matching sessions."
|
||||
|
|
@ -395,6 +400,7 @@ registry.register(
|
|||
role_filter=args.get("role_filter"),
|
||||
limit=args.get("limit", 3),
|
||||
db=kw.get("db"),
|
||||
current_session_id=kw.get("current_session_id")),
|
||||
current_session_id=kw.get("current_session_id"),
|
||||
),
|
||||
check_fn=check_session_search_requirements,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -38,20 +38,21 @@ import os
|
|||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import security scanner — agent-created skills get the same scrutiny as
|
||||
# community hub installs.
|
||||
try:
|
||||
from tools.skills_guard import scan_skill, should_allow_install, format_scan_report
|
||||
from tools.skills_guard import format_scan_report, scan_skill, should_allow_install
|
||||
|
||||
_GUARD_AVAILABLE = True
|
||||
except ImportError:
|
||||
_GUARD_AVAILABLE = False
|
||||
|
||||
|
||||
def _security_scan_skill(skill_dir: Path) -> Optional[str]:
|
||||
def _security_scan_skill(skill_dir: Path) -> str | None:
|
||||
"""Scan a skill directory after write. Returns error string if blocked, else None."""
|
||||
if not _GUARD_AVAILABLE:
|
||||
return None
|
||||
|
|
@ -65,8 +66,8 @@ def _security_scan_skill(skill_dir: Path) -> Optional[str]:
|
|||
logger.warning("Security scan failed for %s: %s", skill_dir, e)
|
||||
return None
|
||||
|
||||
import yaml
|
||||
|
||||
import yaml
|
||||
|
||||
# All skills live in ~/.hermes/skills/ (single source of truth)
|
||||
HERMES_HOME = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
|
@ -76,7 +77,7 @@ MAX_NAME_LENGTH = 64
|
|||
MAX_DESCRIPTION_LENGTH = 1024
|
||||
|
||||
# Characters allowed in skill names (filesystem-safe, URL-friendly)
|
||||
VALID_NAME_RE = re.compile(r'^[a-z0-9][a-z0-9._-]*$')
|
||||
VALID_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9._-]*$")
|
||||
|
||||
# Subdirectories allowed for write_file/remove_file
|
||||
ALLOWED_SUBDIRS = {"references", "templates", "scripts", "assets"}
|
||||
|
|
@ -91,7 +92,8 @@ def check_skill_manage_requirements() -> bool:
|
|||
# Validation helpers
|
||||
# =============================================================================
|
||||
|
||||
def _validate_name(name: str) -> Optional[str]:
|
||||
|
||||
def _validate_name(name: str) -> str | None:
|
||||
"""Validate a skill name. Returns error message or None if valid."""
|
||||
if not name:
|
||||
return "Skill name is required."
|
||||
|
|
@ -105,7 +107,7 @@ def _validate_name(name: str) -> Optional[str]:
|
|||
return None
|
||||
|
||||
|
||||
def _validate_frontmatter(content: str) -> Optional[str]:
|
||||
def _validate_frontmatter(content: str) -> str | None:
|
||||
"""
|
||||
Validate that SKILL.md content has proper frontmatter with required fields.
|
||||
Returns error message or None if valid.
|
||||
|
|
@ -116,11 +118,11 @@ def _validate_frontmatter(content: str) -> Optional[str]:
|
|||
if not content.startswith("---"):
|
||||
return "SKILL.md must start with YAML frontmatter (---). See existing skills for format."
|
||||
|
||||
end_match = re.search(r'\n---\s*\n', content[3:])
|
||||
end_match = re.search(r"\n---\s*\n", content[3:])
|
||||
if not end_match:
|
||||
return "SKILL.md frontmatter is not closed. Ensure you have a closing '---' line."
|
||||
|
||||
yaml_content = content[3:end_match.start() + 3]
|
||||
yaml_content = content[3 : end_match.start() + 3]
|
||||
|
||||
try:
|
||||
parsed = yaml.safe_load(yaml_content)
|
||||
|
|
@ -137,7 +139,7 @@ def _validate_frontmatter(content: str) -> Optional[str]:
|
|||
if len(str(parsed["description"])) > MAX_DESCRIPTION_LENGTH:
|
||||
return f"Description exceeds {MAX_DESCRIPTION_LENGTH} characters."
|
||||
|
||||
body = content[end_match.end() + 3:].strip()
|
||||
body = content[end_match.end() + 3 :].strip()
|
||||
if not body:
|
||||
return "SKILL.md must have content after the frontmatter (instructions, procedures, etc.)."
|
||||
|
||||
|
|
@ -151,7 +153,7 @@ def _resolve_skill_dir(name: str, category: str = None) -> Path:
|
|||
return SKILLS_DIR / name
|
||||
|
||||
|
||||
def _find_skill(name: str) -> Optional[Dict[str, Any]]:
|
||||
def _find_skill(name: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Find a skill by name in ~/.hermes/skills/.
|
||||
Returns {"path": Path} or None.
|
||||
|
|
@ -164,7 +166,7 @@ def _find_skill(name: str) -> Optional[Dict[str, Any]]:
|
|||
return None
|
||||
|
||||
|
||||
def _validate_file_path(file_path: str) -> Optional[str]:
|
||||
def _validate_file_path(file_path: str) -> str | None:
|
||||
"""
|
||||
Validate a file path for write_file/remove_file.
|
||||
Must be under an allowed subdirectory and not escape the skill dir.
|
||||
|
|
@ -194,7 +196,8 @@ def _validate_file_path(file_path: str) -> Optional[str]:
|
|||
# Core actions
|
||||
# =============================================================================
|
||||
|
||||
def _create_skill(name: str, content: str, category: str = None) -> Dict[str, Any]:
|
||||
|
||||
def _create_skill(name: str, content: str, category: str = None) -> dict[str, Any]:
|
||||
"""Create a new user skill with SKILL.md content."""
|
||||
# Validate name
|
||||
err = _validate_name(name)
|
||||
|
|
@ -209,10 +212,7 @@ def _create_skill(name: str, content: str, category: str = None) -> Dict[str, An
|
|||
# Check for name collisions across all directories
|
||||
existing = _find_skill(name)
|
||||
if existing:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"A skill named '{name}' already exists at {existing['path']}."
|
||||
}
|
||||
return {"success": False, "error": f"A skill named '{name}' already exists at {existing['path']}."}
|
||||
|
||||
# Create the skill directory
|
||||
skill_dir = _resolve_skill_dir(name, category)
|
||||
|
|
@ -238,12 +238,12 @@ def _create_skill(name: str, content: str, category: str = None) -> Dict[str, An
|
|||
result["category"] = category
|
||||
result["hint"] = (
|
||||
"To add reference files, templates, or scripts, use "
|
||||
"skill_manage(action='write_file', name='{}', file_path='references/example.md', file_content='...')".format(name)
|
||||
f"skill_manage(action='write_file', name='{name}', file_path='references/example.md', file_content='...')"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _edit_skill(name: str, content: str) -> Dict[str, Any]:
|
||||
def _edit_skill(name: str, content: str) -> dict[str, Any]:
|
||||
"""Replace the SKILL.md of any existing skill (full rewrite)."""
|
||||
err = _validate_frontmatter(content)
|
||||
if err:
|
||||
|
|
@ -278,7 +278,7 @@ def _patch_skill(
|
|||
new_string: str,
|
||||
file_path: str = None,
|
||||
replace_all: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Targeted find-and-replace within a skill file.
|
||||
|
||||
Defaults to SKILL.md. Use file_path to patch a supporting file instead.
|
||||
|
|
@ -287,7 +287,10 @@ def _patch_skill(
|
|||
if not old_string:
|
||||
return {"success": False, "error": "old_string is required for 'patch'."}
|
||||
if new_string is None:
|
||||
return {"success": False, "error": "new_string is required for 'patch'. Use an empty string to delete matched text."}
|
||||
return {
|
||||
"success": False,
|
||||
"error": "new_string is required for 'patch'. Use an empty string to delete matched text.",
|
||||
}
|
||||
|
||||
existing = _find_skill(name)
|
||||
if not existing:
|
||||
|
|
@ -357,7 +360,7 @@ def _patch_skill(
|
|||
}
|
||||
|
||||
|
||||
def _delete_skill(name: str) -> Dict[str, Any]:
|
||||
def _delete_skill(name: str) -> dict[str, Any]:
|
||||
"""Delete a skill."""
|
||||
existing = _find_skill(name)
|
||||
if not existing:
|
||||
|
|
@ -377,7 +380,7 @@ def _delete_skill(name: str) -> Dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]:
|
||||
def _write_file(name: str, file_path: str, file_content: str) -> dict[str, Any]:
|
||||
"""Add or overwrite a supporting file within any skill directory."""
|
||||
err = _validate_file_path(file_path)
|
||||
if err:
|
||||
|
|
@ -412,7 +415,7 @@ def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def _remove_file(name: str, file_path: str) -> Dict[str, Any]:
|
||||
def _remove_file(name: str, file_path: str) -> dict[str, Any]:
|
||||
"""Remove a supporting file from any skill directory."""
|
||||
err = _validate_file_path(file_path)
|
||||
if err:
|
||||
|
|
@ -456,6 +459,7 @@ def _remove_file(name: str, file_path: str) -> Dict[str, Any]:
|
|||
# Main entry point
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def skill_manage(
|
||||
action: str,
|
||||
name: str,
|
||||
|
|
@ -474,19 +478,37 @@ def skill_manage(
|
|||
"""
|
||||
if action == "create":
|
||||
if not content:
|
||||
return json.dumps({"success": False, "error": "content is required for 'create'. Provide the full SKILL.md text (frontmatter + body)."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "content is required for 'create'. Provide the full SKILL.md text (frontmatter + body).",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
result = _create_skill(name, content, category)
|
||||
|
||||
elif action == "edit":
|
||||
if not content:
|
||||
return json.dumps({"success": False, "error": "content is required for 'edit'. Provide the full updated SKILL.md text."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "content is required for 'edit'. Provide the full updated SKILL.md text."},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
result = _edit_skill(name, content)
|
||||
|
||||
elif action == "patch":
|
||||
if not old_string:
|
||||
return json.dumps({"success": False, "error": "old_string is required for 'patch'. Provide the text to find."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "old_string is required for 'patch'. Provide the text to find."},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
if new_string is None:
|
||||
return json.dumps({"success": False, "error": "new_string is required for 'patch'. Use empty string to delete matched text."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "new_string is required for 'patch'. Use empty string to delete matched text.",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
result = _patch_skill(name, old_string, new_string, file_path, replace_all)
|
||||
|
||||
elif action == "delete":
|
||||
|
|
@ -494,18 +516,31 @@ def skill_manage(
|
|||
|
||||
elif action == "write_file":
|
||||
if not file_path:
|
||||
return json.dumps({"success": False, "error": "file_path is required for 'write_file'. Example: 'references/api-guide.md'"}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "file_path is required for 'write_file'. Example: 'references/api-guide.md'",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
if file_content is None:
|
||||
return json.dumps({"success": False, "error": "file_content is required for 'write_file'."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "file_content is required for 'write_file'."}, ensure_ascii=False
|
||||
)
|
||||
result = _write_file(name, file_path, file_content)
|
||||
|
||||
elif action == "remove_file":
|
||||
if not file_path:
|
||||
return json.dumps({"success": False, "error": "file_path is required for 'remove_file'."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "file_path is required for 'remove_file'."}, ensure_ascii=False
|
||||
)
|
||||
result = _remove_file(name, file_path)
|
||||
|
||||
else:
|
||||
result = {"success": False, "error": f"Unknown action '{action}'. Use: create, edit, patch, delete, write_file, remove_file"}
|
||||
result = {
|
||||
"success": False,
|
||||
"error": f"Unknown action '{action}'. Use: create, edit, patch, delete, write_file, remove_file",
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
|
||||
|
|
@ -540,14 +575,14 @@ SKILL_MANAGE_SCHEMA = {
|
|||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["create", "patch", "edit", "delete", "write_file", "remove_file"],
|
||||
"description": "The action to perform."
|
||||
"description": "The action to perform.",
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Skill name (lowercase, hyphens/underscores, max 64 chars). "
|
||||
"Must match an existing skill for patch/edit/delete/write_file/remove_file."
|
||||
)
|
||||
),
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
|
|
@ -555,7 +590,7 @@ SKILL_MANAGE_SCHEMA = {
|
|||
"Full SKILL.md content (YAML frontmatter + markdown body). "
|
||||
"Required for 'create' and 'edit'. For 'edit', read the skill "
|
||||
"first with skill_view() and provide the complete updated text."
|
||||
)
|
||||
),
|
||||
},
|
||||
"old_string": {
|
||||
"type": "string",
|
||||
|
|
@ -563,18 +598,17 @@ SKILL_MANAGE_SCHEMA = {
|
|||
"Text to find in the file (required for 'patch'). Must be unique "
|
||||
"unless replace_all=true. Include enough surrounding context to "
|
||||
"ensure uniqueness."
|
||||
)
|
||||
),
|
||||
},
|
||||
"new_string": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Replacement text (required for 'patch'). Can be empty string "
|
||||
"to delete the matched text."
|
||||
)
|
||||
"Replacement text (required for 'patch'). Can be empty string to delete the matched text."
|
||||
),
|
||||
},
|
||||
"replace_all": {
|
||||
"type": "boolean",
|
||||
"description": "For 'patch': replace all occurrences instead of requiring a unique match (default: false)."
|
||||
"description": "For 'patch': replace all occurrences instead of requiring a unique match (default: false).",
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
|
|
@ -582,7 +616,7 @@ SKILL_MANAGE_SCHEMA = {
|
|||
"Optional category/domain for organizing the skill (e.g., 'devops', "
|
||||
"'data-science', 'mlops'). Creates a subdirectory grouping. "
|
||||
"Only used with 'create'."
|
||||
)
|
||||
),
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
|
|
@ -591,12 +625,9 @@ SKILL_MANAGE_SCHEMA = {
|
|||
"For 'write_file'/'remove_file': required, must be under references/, "
|
||||
"templates/, scripts/, or assets/. "
|
||||
"For 'patch': optional, defaults to SKILL.md if omitted."
|
||||
)
|
||||
},
|
||||
"file_content": {
|
||||
"type": "string",
|
||||
"description": "Content for the file. Required for 'write_file'."
|
||||
),
|
||||
},
|
||||
"file_content": {"type": "string", "description": "Content for the file. Required for 'write_file'."},
|
||||
},
|
||||
"required": ["action", "name"],
|
||||
},
|
||||
|
|
@ -619,5 +650,6 @@ registry.register(
|
|||
file_content=args.get("file_content"),
|
||||
old_string=args.get("old_string"),
|
||||
new_string=args.get("new_string"),
|
||||
replace_all=args.get("replace_all", False)),
|
||||
replace_all=args.get("replace_all", False),
|
||||
),
|
||||
)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -23,15 +23,17 @@ import subprocess
|
|||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
|
||||
from tools.skills_guard import (
|
||||
ScanResult, scan_skill, should_allow_install, content_hash, TRUSTED_REPOS,
|
||||
TRUSTED_REPOS,
|
||||
ScanResult,
|
||||
content_hash,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -58,24 +60,27 @@ INDEX_CACHE_TTL = 3600 # 1 hour
|
|||
# Data models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillMeta:
|
||||
"""Minimal metadata returned by search results."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
source: str # "official", "github", "clawhub", "claude-marketplace", "lobehub"
|
||||
identifier: str # source-specific ID (e.g. "openai/skills/skill-creator")
|
||||
trust_level: str # "builtin" | "trusted" | "community"
|
||||
repo: Optional[str] = None
|
||||
path: Optional[str] = None
|
||||
tags: List[str] = field(default_factory=list)
|
||||
source: str # "official", "github", "clawhub", "claude-marketplace", "lobehub"
|
||||
identifier: str # source-specific ID (e.g. "openai/skills/skill-creator")
|
||||
trust_level: str # "builtin" | "trusted" | "community"
|
||||
repo: str | None = None
|
||||
path: str | None = None
|
||||
tags: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillBundle:
|
||||
"""A downloaded skill ready for quarantine/scanning/installation."""
|
||||
|
||||
name: str
|
||||
files: Dict[str, str] # relative_path -> text content
|
||||
files: dict[str, str] # relative_path -> text content
|
||||
source: str
|
||||
identifier: str
|
||||
trust_level: str
|
||||
|
|
@ -85,6 +90,7 @@ class SkillBundle:
|
|||
# GitHub Authentication
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GitHubAuth:
|
||||
"""
|
||||
GitHub API authentication. Tries methods in priority order:
|
||||
|
|
@ -95,11 +101,11 @@ class GitHubAuth:
|
|||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._cached_token: Optional[str] = None
|
||||
self._cached_method: Optional[str] = None
|
||||
self._cached_token: str | None = None
|
||||
self._cached_method: str | None = None
|
||||
self._app_token_expiry: float = 0
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
"""Return authorization headers for GitHub API requests."""
|
||||
token = self._resolve_token()
|
||||
headers = {"Accept": "application/vnd.github.v3+json"}
|
||||
|
|
@ -115,7 +121,7 @@ class GitHubAuth:
|
|||
self._resolve_token()
|
||||
return self._cached_method or "anonymous"
|
||||
|
||||
def _resolve_token(self) -> Optional[str]:
|
||||
def _resolve_token(self) -> str | None:
|
||||
# Return cached token if still valid
|
||||
if self._cached_token:
|
||||
if self._cached_method != "github-app" or time.time() < self._app_token_expiry:
|
||||
|
|
@ -146,12 +152,14 @@ class GitHubAuth:
|
|||
self._cached_method = "anonymous"
|
||||
return None
|
||||
|
||||
def _try_gh_cli(self) -> Optional[str]:
|
||||
def _try_gh_cli(self) -> str | None:
|
||||
"""Try to get a token from the gh CLI."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["gh", "auth", "token"],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
return result.stdout.strip()
|
||||
|
|
@ -159,7 +167,7 @@ class GitHubAuth:
|
|||
logger.debug("gh CLI token lookup failed: %s", e)
|
||||
return None
|
||||
|
||||
def _try_github_app(self) -> Optional[str]:
|
||||
def _try_github_app(self) -> str | None:
|
||||
"""Try GitHub App JWT authentication if credentials are configured."""
|
||||
app_id = os.environ.get("GITHUB_APP_ID")
|
||||
key_path = os.environ.get("GITHUB_APP_PRIVATE_KEY_PATH")
|
||||
|
|
@ -208,21 +216,22 @@ class GitHubAuth:
|
|||
# Source adapter interface
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SkillSource(ABC):
|
||||
"""Abstract base for all skill registry adapters."""
|
||||
|
||||
@abstractmethod
|
||||
def search(self, query: str, limit: int = 10) -> List[SkillMeta]:
|
||||
def search(self, query: str, limit: int = 10) -> list[SkillMeta]:
|
||||
"""Search for skills matching a query string."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def fetch(self, identifier: str) -> Optional[SkillBundle]:
|
||||
def fetch(self, identifier: str) -> SkillBundle | None:
|
||||
"""Download a skill bundle by identifier."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def inspect(self, identifier: str) -> Optional[SkillMeta]:
|
||||
def inspect(self, identifier: str) -> SkillMeta | None:
|
||||
"""Fetch metadata for a skill without downloading all files."""
|
||||
...
|
||||
|
||||
|
|
@ -240,6 +249,7 @@ class SkillSource(ABC):
|
|||
# GitHub source adapter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GitHubSource(SkillSource):
|
||||
"""Fetch skills from GitHub repos via the Contents API."""
|
||||
|
||||
|
|
@ -249,7 +259,7 @@ class GitHubSource(SkillSource):
|
|||
{"repo": "VoltAgent/awesome-agent-skills", "path": "skills/"},
|
||||
]
|
||||
|
||||
def __init__(self, auth: GitHubAuth, extra_taps: Optional[List[Dict]] = None):
|
||||
def __init__(self, auth: GitHubAuth, extra_taps: list[dict] | None = None):
|
||||
self.auth = auth
|
||||
self.taps = list(self.DEFAULT_TAPS)
|
||||
if extra_taps:
|
||||
|
|
@ -267,9 +277,9 @@ class GitHubSource(SkillSource):
|
|||
return "trusted"
|
||||
return "community"
|
||||
|
||||
def search(self, query: str, limit: int = 10) -> List[SkillMeta]:
|
||||
def search(self, query: str, limit: int = 10) -> list[SkillMeta]:
|
||||
"""Search all taps for skills matching the query."""
|
||||
results: List[SkillMeta] = []
|
||||
results: list[SkillMeta] = []
|
||||
query_lower = query.lower()
|
||||
|
||||
for tap in self.taps:
|
||||
|
|
@ -287,15 +297,13 @@ class GitHubSource(SkillSource):
|
|||
_trust_rank = {"builtin": 2, "trusted": 1, "community": 0}
|
||||
seen = {}
|
||||
for r in results:
|
||||
if r.name not in seen:
|
||||
seen[r.name] = r
|
||||
elif _trust_rank.get(r.trust_level, 0) > _trust_rank.get(seen[r.name].trust_level, 0):
|
||||
if r.name not in seen or _trust_rank.get(r.trust_level, 0) > _trust_rank.get(seen[r.name].trust_level, 0):
|
||||
seen[r.name] = r
|
||||
results = list(seen.values())
|
||||
|
||||
return results[:limit]
|
||||
|
||||
def fetch(self, identifier: str) -> Optional[SkillBundle]:
|
||||
def fetch(self, identifier: str) -> SkillBundle | None:
|
||||
"""
|
||||
Download a skill from GitHub.
|
||||
identifier format: "owner/repo/path/to/skill-dir"
|
||||
|
|
@ -322,7 +330,7 @@ class GitHubSource(SkillSource):
|
|||
trust_level=trust,
|
||||
)
|
||||
|
||||
def inspect(self, identifier: str) -> Optional[SkillMeta]:
|
||||
def inspect(self, identifier: str) -> SkillMeta | None:
|
||||
"""Fetch just the SKILL.md metadata for preview."""
|
||||
parts = identifier.split("/", 2)
|
||||
if len(parts) < 3:
|
||||
|
|
@ -363,7 +371,7 @@ class GitHubSource(SkillSource):
|
|||
|
||||
# -- Internal helpers --
|
||||
|
||||
def _list_skills_in_repo(self, repo: str, path: str) -> List[SkillMeta]:
|
||||
def _list_skills_in_repo(self, repo: str, path: str) -> list[SkillMeta]:
|
||||
"""List skill directories in a GitHub repo path, using cached index."""
|
||||
cache_key = f"{repo}_{path}".replace("/", "_").replace(" ", "_")
|
||||
cached = self._read_cache(cache_key)
|
||||
|
|
@ -382,7 +390,7 @@ class GitHubSource(SkillSource):
|
|||
if not isinstance(entries, list):
|
||||
return []
|
||||
|
||||
skills: List[SkillMeta] = []
|
||||
skills: list[SkillMeta] = []
|
||||
for entry in entries:
|
||||
if entry.get("type") != "dir":
|
||||
continue
|
||||
|
|
@ -400,7 +408,7 @@ class GitHubSource(SkillSource):
|
|||
self._write_cache(cache_key, [self._meta_to_dict(s) for s in skills])
|
||||
return skills
|
||||
|
||||
def _download_directory(self, repo: str, path: str) -> Dict[str, str]:
|
||||
def _download_directory(self, repo: str, path: str) -> dict[str, str]:
|
||||
"""Recursively download all text files from a GitHub directory."""
|
||||
url = f"https://api.github.com/repos/{repo}/contents/{path.rstrip('/')}"
|
||||
try:
|
||||
|
|
@ -414,7 +422,7 @@ class GitHubSource(SkillSource):
|
|||
if not isinstance(entries, list):
|
||||
return {}
|
||||
|
||||
files: Dict[str, str] = {}
|
||||
files: dict[str, str] = {}
|
||||
for entry in entries:
|
||||
name = entry.get("name", "")
|
||||
entry_type = entry.get("type", "")
|
||||
|
|
@ -431,7 +439,7 @@ class GitHubSource(SkillSource):
|
|||
|
||||
return files
|
||||
|
||||
def _fetch_file_content(self, repo: str, path: str) -> Optional[str]:
|
||||
def _fetch_file_content(self, repo: str, path: str) -> str | None:
|
||||
"""Fetch a single file's content from GitHub."""
|
||||
url = f"https://api.github.com/repos/{repo}/contents/{path}"
|
||||
try:
|
||||
|
|
@ -446,7 +454,7 @@ class GitHubSource(SkillSource):
|
|||
logger.debug("GitHub contents API fetch failed: %s", e)
|
||||
return None
|
||||
|
||||
def _read_cache(self, key: str) -> Optional[list]:
|
||||
def _read_cache(self, key: str) -> list | None:
|
||||
"""Read cached index if not expired."""
|
||||
cache_file = INDEX_CACHE_DIR / f"{key}.json"
|
||||
if not cache_file.exists():
|
||||
|
|
@ -486,10 +494,10 @@ class GitHubSource(SkillSource):
|
|||
"""Parse YAML frontmatter from SKILL.md content."""
|
||||
if not content.startswith("---"):
|
||||
return {}
|
||||
match = re.search(r'\n---\s*\n', content[3:])
|
||||
match = re.search(r"\n---\s*\n", content[3:])
|
||||
if not match:
|
||||
return {}
|
||||
yaml_text = content[3:match.start() + 3]
|
||||
yaml_text = content[3 : match.start() + 3]
|
||||
try:
|
||||
parsed = yaml.safe_load(yaml_text)
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
|
|
@ -501,6 +509,7 @@ class GitHubSource(SkillSource):
|
|||
# ClawHub source adapter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ClawHubSource(SkillSource):
|
||||
"""
|
||||
Fetch skills from ClawHub (clawhub.ai) via their HTTP API.
|
||||
|
|
@ -516,7 +525,7 @@ class ClawHubSource(SkillSource):
|
|||
def trust_level_for(self, identifier: str) -> str:
|
||||
return "community"
|
||||
|
||||
def search(self, query: str, limit: int = 10) -> List[SkillMeta]:
|
||||
def search(self, query: str, limit: int = 10) -> list[SkillMeta]:
|
||||
cache_key = f"clawhub_search_{hashlib.md5(query.encode()).hexdigest()}"
|
||||
cached = _read_index_cache(cache_key)
|
||||
if cached is not None:
|
||||
|
|
@ -548,19 +557,21 @@ class ClawHubSource(SkillSource):
|
|||
tags = item.get("tags", [])
|
||||
if not isinstance(tags, list):
|
||||
tags = []
|
||||
results.append(SkillMeta(
|
||||
name=display_name,
|
||||
description=summary,
|
||||
source="clawhub",
|
||||
identifier=slug,
|
||||
trust_level="community",
|
||||
tags=[str(t) for t in tags],
|
||||
))
|
||||
results.append(
|
||||
SkillMeta(
|
||||
name=display_name,
|
||||
description=summary,
|
||||
source="clawhub",
|
||||
identifier=slug,
|
||||
trust_level="community",
|
||||
tags=[str(t) for t in tags],
|
||||
)
|
||||
)
|
||||
|
||||
_write_index_cache(cache_key, [_skill_meta_to_dict(s) for s in results])
|
||||
return results
|
||||
|
||||
def fetch(self, identifier: str) -> Optional[SkillBundle]:
|
||||
def fetch(self, identifier: str) -> SkillBundle | None:
|
||||
slug = identifier.split("/")[-1]
|
||||
|
||||
skill_data = self._get_json(f"{self.BASE_URL}/skills/{slug}")
|
||||
|
|
@ -593,7 +604,7 @@ class ClawHubSource(SkillSource):
|
|||
trust_level="community",
|
||||
)
|
||||
|
||||
def inspect(self, identifier: str) -> Optional[SkillMeta]:
|
||||
def inspect(self, identifier: str) -> SkillMeta | None:
|
||||
slug = identifier.split("/")[-1]
|
||||
data = self._get_json(f"{self.BASE_URL}/skills/{slug}")
|
||||
if not isinstance(data, dict):
|
||||
|
|
@ -612,7 +623,7 @@ class ClawHubSource(SkillSource):
|
|||
tags=[str(t) for t in tags],
|
||||
)
|
||||
|
||||
def _get_json(self, url: str, timeout: int = 20) -> Optional[Any]:
|
||||
def _get_json(self, url: str, timeout: int = 20) -> Any | None:
|
||||
try:
|
||||
resp = httpx.get(url, timeout=timeout)
|
||||
if resp.status_code != 200:
|
||||
|
|
@ -621,7 +632,7 @@ class ClawHubSource(SkillSource):
|
|||
except (httpx.HTTPError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
def _resolve_latest_version(self, slug: str, skill_data: Dict[str, Any]) -> Optional[str]:
|
||||
def _resolve_latest_version(self, slug: str, skill_data: dict[str, Any]) -> str | None:
|
||||
latest = skill_data.get("latestVersion")
|
||||
if isinstance(latest, dict):
|
||||
version = latest.get("version")
|
||||
|
|
@ -643,8 +654,8 @@ class ClawHubSource(SkillSource):
|
|||
return version
|
||||
return None
|
||||
|
||||
def _extract_files(self, version_data: Dict[str, Any]) -> Dict[str, str]:
|
||||
files: Dict[str, str] = {}
|
||||
def _extract_files(self, version_data: dict[str, Any]) -> dict[str, str]:
|
||||
files: dict[str, str] = {}
|
||||
file_list = version_data.get("files")
|
||||
|
||||
if isinstance(file_list, dict):
|
||||
|
|
@ -674,7 +685,7 @@ class ClawHubSource(SkillSource):
|
|||
|
||||
return files
|
||||
|
||||
def _fetch_text(self, url: str) -> Optional[str]:
|
||||
def _fetch_text(self, url: str) -> str | None:
|
||||
try:
|
||||
resp = httpx.get(url, timeout=20)
|
||||
if resp.status_code == 200:
|
||||
|
|
@ -688,6 +699,7 @@ class ClawHubSource(SkillSource):
|
|||
# Claude Code marketplace source adapter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ClaudeMarketplaceSource(SkillSource):
|
||||
"""
|
||||
Discover skills from Claude Code marketplace repos.
|
||||
|
|
@ -713,8 +725,8 @@ class ClaudeMarketplaceSource(SkillSource):
|
|||
return "trusted"
|
||||
return "community"
|
||||
|
||||
def search(self, query: str, limit: int = 10) -> List[SkillMeta]:
|
||||
results: List[SkillMeta] = []
|
||||
def search(self, query: str, limit: int = 10) -> list[SkillMeta]:
|
||||
results: list[SkillMeta] = []
|
||||
query_lower = query.lower()
|
||||
|
||||
for marketplace_repo in self.KNOWN_MARKETPLACES:
|
||||
|
|
@ -730,18 +742,20 @@ class ClaudeMarketplaceSource(SkillSource):
|
|||
else:
|
||||
identifier = f"{marketplace_repo}/{source_path}"
|
||||
|
||||
results.append(SkillMeta(
|
||||
name=plugin.get("name", ""),
|
||||
description=plugin.get("description", ""),
|
||||
source="claude-marketplace",
|
||||
identifier=identifier,
|
||||
trust_level=self.trust_level_for(identifier),
|
||||
repo=marketplace_repo,
|
||||
))
|
||||
results.append(
|
||||
SkillMeta(
|
||||
name=plugin.get("name", ""),
|
||||
description=plugin.get("description", ""),
|
||||
source="claude-marketplace",
|
||||
identifier=identifier,
|
||||
trust_level=self.trust_level_for(identifier),
|
||||
repo=marketplace_repo,
|
||||
)
|
||||
)
|
||||
|
||||
return results[:limit]
|
||||
|
||||
def fetch(self, identifier: str) -> Optional[SkillBundle]:
|
||||
def fetch(self, identifier: str) -> SkillBundle | None:
|
||||
# Delegate to GitHub Contents API since marketplace skills live in GitHub repos
|
||||
gh = GitHubSource(auth=self.auth)
|
||||
bundle = gh.fetch(identifier)
|
||||
|
|
@ -749,7 +763,7 @@ class ClaudeMarketplaceSource(SkillSource):
|
|||
bundle.source = "claude-marketplace"
|
||||
return bundle
|
||||
|
||||
def inspect(self, identifier: str) -> Optional[SkillMeta]:
|
||||
def inspect(self, identifier: str) -> SkillMeta | None:
|
||||
gh = GitHubSource(auth=self.auth)
|
||||
meta = gh.inspect(identifier)
|
||||
if meta:
|
||||
|
|
@ -757,7 +771,7 @@ class ClaudeMarketplaceSource(SkillSource):
|
|||
meta.trust_level = self.trust_level_for(identifier)
|
||||
return meta
|
||||
|
||||
def _fetch_marketplace_index(self, repo: str) -> List[dict]:
|
||||
def _fetch_marketplace_index(self, repo: str) -> list[dict]:
|
||||
"""Fetch and parse .claude-plugin/marketplace.json from a repo."""
|
||||
cache_key = f"claude_marketplace_{repo.replace('/', '_')}"
|
||||
cached = _read_index_cache(cache_key)
|
||||
|
|
@ -786,6 +800,7 @@ class ClaudeMarketplaceSource(SkillSource):
|
|||
# LobeHub source adapter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class LobeHubSource(SkillSource):
|
||||
"""
|
||||
Fetch skills from LobeHub's agent marketplace (14,500+ agents).
|
||||
|
|
@ -802,13 +817,13 @@ class LobeHubSource(SkillSource):
|
|||
def trust_level_for(self, identifier: str) -> str:
|
||||
return "community"
|
||||
|
||||
def search(self, query: str, limit: int = 10) -> List[SkillMeta]:
|
||||
def search(self, query: str, limit: int = 10) -> list[SkillMeta]:
|
||||
index = self._fetch_index()
|
||||
if not index:
|
||||
return []
|
||||
|
||||
query_lower = query.lower()
|
||||
results: List[SkillMeta] = []
|
||||
results: list[SkillMeta] = []
|
||||
|
||||
agents = index.get("agents", index) if isinstance(index, dict) else index
|
||||
if not isinstance(agents, list):
|
||||
|
|
@ -823,21 +838,23 @@ class LobeHubSource(SkillSource):
|
|||
searchable = f"{title} {desc} {' '.join(tags) if isinstance(tags, list) else ''}".lower()
|
||||
if query_lower in searchable:
|
||||
identifier = agent.get("identifier", title.lower().replace(" ", "-"))
|
||||
results.append(SkillMeta(
|
||||
name=identifier,
|
||||
description=desc[:200],
|
||||
source="lobehub",
|
||||
identifier=f"lobehub/{identifier}",
|
||||
trust_level="community",
|
||||
tags=tags if isinstance(tags, list) else [],
|
||||
))
|
||||
results.append(
|
||||
SkillMeta(
|
||||
name=identifier,
|
||||
description=desc[:200],
|
||||
source="lobehub",
|
||||
identifier=f"lobehub/{identifier}",
|
||||
trust_level="community",
|
||||
tags=tags if isinstance(tags, list) else [],
|
||||
)
|
||||
)
|
||||
|
||||
if len(results) >= limit:
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
def fetch(self, identifier: str) -> Optional[SkillBundle]:
|
||||
def fetch(self, identifier: str) -> SkillBundle | None:
|
||||
# Strip "lobehub/" prefix if present
|
||||
agent_id = identifier.split("/", 1)[-1] if identifier.startswith("lobehub/") else identifier
|
||||
|
||||
|
|
@ -854,7 +871,7 @@ class LobeHubSource(SkillSource):
|
|||
trust_level="community",
|
||||
)
|
||||
|
||||
def inspect(self, identifier: str) -> Optional[SkillMeta]:
|
||||
def inspect(self, identifier: str) -> SkillMeta | None:
|
||||
agent_id = identifier.split("/", 1)[-1] if identifier.startswith("lobehub/") else identifier
|
||||
index = self._fetch_index()
|
||||
if not index:
|
||||
|
|
@ -877,7 +894,7 @@ class LobeHubSource(SkillSource):
|
|||
)
|
||||
return None
|
||||
|
||||
def _fetch_index(self) -> Optional[Any]:
|
||||
def _fetch_index(self) -> Any | None:
|
||||
"""Fetch the LobeHub agent index (cached for 1 hour)."""
|
||||
cache_key = "lobehub_index"
|
||||
cached = _read_index_cache(cache_key)
|
||||
|
|
@ -895,7 +912,7 @@ class LobeHubSource(SkillSource):
|
|||
_write_index_cache(cache_key, data)
|
||||
return data
|
||||
|
||||
def _fetch_agent(self, agent_id: str) -> Optional[dict]:
|
||||
def _fetch_agent(self, agent_id: str) -> dict | None:
|
||||
"""Fetch a single agent's JSON file."""
|
||||
url = f"https://chat-agents.lobehub.com/{agent_id}.json"
|
||||
try:
|
||||
|
|
@ -924,8 +941,8 @@ class LobeHubSource(SkillSource):
|
|||
"metadata:",
|
||||
" hermes:",
|
||||
f" tags: [{', '.join(str(t) for t in tag_list)}]",
|
||||
f" lobehub:",
|
||||
f" source: lobehub",
|
||||
" lobehub:",
|
||||
" source: lobehub",
|
||||
"---",
|
||||
]
|
||||
|
||||
|
|
@ -946,6 +963,7 @@ class LobeHubSource(SkillSource):
|
|||
# Official optional skills source adapter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class OptionalSkillSource(SkillSource):
|
||||
"""
|
||||
Fetch skills from the optional-skills/ directory shipped with the repo.
|
||||
|
|
@ -967,8 +985,8 @@ class OptionalSkillSource(SkillSource):
|
|||
|
||||
# -- search -----------------------------------------------------------
|
||||
|
||||
def search(self, query: str, limit: int = 10) -> List[SkillMeta]:
|
||||
results: List[SkillMeta] = []
|
||||
def search(self, query: str, limit: int = 10) -> list[SkillMeta]:
|
||||
results: list[SkillMeta] = []
|
||||
query_lower = query.lower()
|
||||
|
||||
for meta in self._scan_all():
|
||||
|
|
@ -982,7 +1000,7 @@ class OptionalSkillSource(SkillSource):
|
|||
|
||||
# -- fetch ------------------------------------------------------------
|
||||
|
||||
def fetch(self, identifier: str) -> Optional[SkillBundle]:
|
||||
def fetch(self, identifier: str) -> SkillBundle | None:
|
||||
# identifier format: "official/category/skill" or "official/skill"
|
||||
rel = identifier.split("/", 1)[-1] if identifier.startswith("official/") else identifier
|
||||
skill_dir = self._optional_dir / rel
|
||||
|
|
@ -1004,7 +1022,7 @@ class OptionalSkillSource(SkillSource):
|
|||
else:
|
||||
skill_dir = resolved
|
||||
|
||||
files: Dict[str, str] = {}
|
||||
files: dict[str, str] = {}
|
||||
for f in skill_dir.rglob("*"):
|
||||
if f.is_file() and not f.name.startswith("."):
|
||||
rel_path = str(f.relative_to(skill_dir))
|
||||
|
|
@ -1029,7 +1047,7 @@ class OptionalSkillSource(SkillSource):
|
|||
|
||||
# -- inspect ----------------------------------------------------------
|
||||
|
||||
def inspect(self, identifier: str) -> Optional[SkillMeta]:
|
||||
def inspect(self, identifier: str) -> SkillMeta | None:
|
||||
rel = identifier.split("/", 1)[-1] if identifier.startswith("official/") else identifier
|
||||
skill_name = rel.rsplit("/", 1)[-1]
|
||||
|
||||
|
|
@ -1040,7 +1058,7 @@ class OptionalSkillSource(SkillSource):
|
|||
|
||||
# -- internal helpers -------------------------------------------------
|
||||
|
||||
def _find_skill_dir(self, name: str) -> Optional[Path]:
|
||||
def _find_skill_dir(self, name: str) -> Path | None:
|
||||
"""Find a skill directory by name anywhere in optional-skills/."""
|
||||
if not self._optional_dir.is_dir():
|
||||
return None
|
||||
|
|
@ -1049,12 +1067,12 @@ class OptionalSkillSource(SkillSource):
|
|||
return skill_md.parent
|
||||
return None
|
||||
|
||||
def _scan_all(self) -> List[SkillMeta]:
|
||||
def _scan_all(self) -> list[SkillMeta]:
|
||||
"""Enumerate all optional skills with metadata."""
|
||||
if not self._optional_dir.is_dir():
|
||||
return []
|
||||
|
||||
results: List[SkillMeta] = []
|
||||
results: list[SkillMeta] = []
|
||||
for skill_md in sorted(self._optional_dir.rglob("SKILL.md")):
|
||||
parent = skill_md.parent
|
||||
rel_parts = parent.relative_to(self._optional_dir).parts
|
||||
|
|
@ -1078,15 +1096,17 @@ class OptionalSkillSource(SkillSource):
|
|||
|
||||
rel_path = str(parent.relative_to(self._optional_dir))
|
||||
|
||||
results.append(SkillMeta(
|
||||
name=name,
|
||||
description=desc[:200],
|
||||
source="official",
|
||||
identifier=f"official/{rel_path}",
|
||||
trust_level="builtin",
|
||||
path=rel_path,
|
||||
tags=tags if isinstance(tags, list) else [],
|
||||
))
|
||||
results.append(
|
||||
SkillMeta(
|
||||
name=name,
|
||||
description=desc[:200],
|
||||
source="official",
|
||||
identifier=f"official/{rel_path}",
|
||||
trust_level="builtin",
|
||||
path=rel_path,
|
||||
tags=tags if isinstance(tags, list) else [],
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
|
@ -1095,10 +1115,10 @@ class OptionalSkillSource(SkillSource):
|
|||
"""Parse YAML frontmatter from SKILL.md content."""
|
||||
if not content.startswith("---"):
|
||||
return {}
|
||||
match = re.search(r'\n---\s*\n', content[3:])
|
||||
match = re.search(r"\n---\s*\n", content[3:])
|
||||
if not match:
|
||||
return {}
|
||||
yaml_text = content[3:match.start() + 3]
|
||||
yaml_text = content[3 : match.start() + 3]
|
||||
try:
|
||||
parsed = yaml.safe_load(yaml_text)
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
|
|
@ -1110,7 +1130,8 @@ class OptionalSkillSource(SkillSource):
|
|||
# Shared cache helpers (used by multiple adapters)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _read_index_cache(key: str) -> Optional[Any]:
|
||||
|
||||
def _read_index_cache(key: str) -> Any | None:
|
||||
"""Read cached data if not expired."""
|
||||
cache_file = INDEX_CACHE_DIR / f"{key}.json"
|
||||
if not cache_file.exists():
|
||||
|
|
@ -1152,6 +1173,7 @@ def _skill_meta_to_dict(meta: SkillMeta) -> dict:
|
|||
# Lock file management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HubLockFile:
|
||||
"""Manages skills/.hub/lock.json — tracks provenance of installed hub skills."""
|
||||
|
||||
|
|
@ -1179,7 +1201,7 @@ class HubLockFile:
|
|||
scan_verdict: str,
|
||||
skill_hash: str,
|
||||
install_path: str,
|
||||
files: List[str],
|
||||
files: list[str],
|
||||
) -> None:
|
||||
data = self.load()
|
||||
data["installed"][name] = {
|
||||
|
|
@ -1190,8 +1212,8 @@ class HubLockFile:
|
|||
"content_hash": skill_hash,
|
||||
"install_path": install_path,
|
||||
"files": files,
|
||||
"installed_at": datetime.now(timezone.utc).isoformat(),
|
||||
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||
"installed_at": datetime.now(UTC).isoformat(),
|
||||
"updated_at": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
self.save(data)
|
||||
|
||||
|
|
@ -1200,11 +1222,11 @@ class HubLockFile:
|
|||
data["installed"].pop(name, None)
|
||||
self.save(data)
|
||||
|
||||
def get_installed(self, name: str) -> Optional[dict]:
|
||||
def get_installed(self, name: str) -> dict | None:
|
||||
data = self.load()
|
||||
return data["installed"].get(name)
|
||||
|
||||
def list_installed(self) -> List[dict]:
|
||||
def list_installed(self) -> list[dict]:
|
||||
data = self.load()
|
||||
result = []
|
||||
for name, entry in data["installed"].items():
|
||||
|
|
@ -1220,13 +1242,14 @@ class HubLockFile:
|
|||
# Taps management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TapsManager:
|
||||
"""Manages the taps.json file — custom GitHub repo sources."""
|
||||
|
||||
def __init__(self, path: Path = TAPS_FILE):
|
||||
self.path = path
|
||||
|
||||
def load(self) -> List[dict]:
|
||||
def load(self) -> list[dict]:
|
||||
if not self.path.exists():
|
||||
return []
|
||||
try:
|
||||
|
|
@ -1235,7 +1258,7 @@ class TapsManager:
|
|||
except (json.JSONDecodeError, OSError):
|
||||
return []
|
||||
|
||||
def save(self, taps: List[dict]) -> None:
|
||||
def save(self, taps: list[dict]) -> None:
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.path.write_text(json.dumps({"taps": taps}, indent=2) + "\n")
|
||||
|
||||
|
|
@ -1257,7 +1280,7 @@ class TapsManager:
|
|||
self.save(new_taps)
|
||||
return True
|
||||
|
||||
def list_taps(self) -> List[dict]:
|
||||
def list_taps(self) -> list[dict]:
|
||||
return self.load()
|
||||
|
||||
|
||||
|
|
@ -1265,11 +1288,13 @@ class TapsManager:
|
|||
# Audit log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def append_audit_log(action: str, skill_name: str, source: str,
|
||||
trust_level: str, verdict: str, extra: str = "") -> None:
|
||||
|
||||
def append_audit_log(
|
||||
action: str, skill_name: str, source: str, trust_level: str, verdict: str, extra: str = ""
|
||||
) -> None:
|
||||
"""Append a line to the audit log."""
|
||||
AUDIT_LOG.parent.mkdir(parents=True, exist_ok=True)
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
timestamp = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
parts = [timestamp, action, skill_name, f"{source}:{trust_level}", verdict]
|
||||
if extra:
|
||||
parts.append(extra)
|
||||
|
|
@ -1285,6 +1310,7 @@ def append_audit_log(action: str, skill_name: str, source: str,
|
|||
# Hub operations (high-level)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def ensure_hub_dirs() -> None:
|
||||
"""Create the .hub directory structure if it doesn't exist."""
|
||||
HUB_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -1347,15 +1373,18 @@ def install_from_quarantine(
|
|||
)
|
||||
|
||||
append_audit_log(
|
||||
"INSTALL", skill_name, bundle.source,
|
||||
bundle.trust_level, scan_result.verdict,
|
||||
"INSTALL",
|
||||
skill_name,
|
||||
bundle.source,
|
||||
bundle.trust_level,
|
||||
scan_result.verdict,
|
||||
content_hash(install_dir),
|
||||
)
|
||||
|
||||
return install_dir
|
||||
|
||||
|
||||
def uninstall_skill(skill_name: str) -> Tuple[bool, str]:
|
||||
def uninstall_skill(skill_name: str) -> tuple[bool, str]:
|
||||
"""Remove a hub-installed skill. Refuses to remove builtins."""
|
||||
lock = HubLockFile()
|
||||
entry = lock.get_installed(skill_name)
|
||||
|
|
@ -1372,7 +1401,7 @@ def uninstall_skill(skill_name: str) -> Tuple[bool, str]:
|
|||
return True, f"Uninstalled '{skill_name}' from {entry['install_path']}"
|
||||
|
||||
|
||||
def create_source_router(auth: Optional[GitHubAuth] = None) -> List[SkillSource]:
|
||||
def create_source_router(auth: GitHubAuth | None = None) -> list[SkillSource]:
|
||||
"""
|
||||
Create all configured source adapters.
|
||||
Returns a list of active sources for search/fetch operations.
|
||||
|
|
@ -1383,8 +1412,8 @@ def create_source_router(auth: Optional[GitHubAuth] = None) -> List[SkillSource]
|
|||
taps_mgr = TapsManager()
|
||||
extra_taps = taps_mgr.list_taps()
|
||||
|
||||
sources: List[SkillSource] = [
|
||||
OptionalSkillSource(), # Official optional skills (highest priority)
|
||||
sources: list[SkillSource] = [
|
||||
OptionalSkillSource(), # Official optional skills (highest priority)
|
||||
GitHubSource(auth=auth, extra_taps=extra_taps),
|
||||
ClawHubSource(),
|
||||
ClaudeMarketplaceSource(auth=auth),
|
||||
|
|
@ -1394,10 +1423,11 @@ def create_source_router(auth: Optional[GitHubAuth] = None) -> List[SkillSource]
|
|||
return sources
|
||||
|
||||
|
||||
def unified_search(query: str, sources: List[SkillSource],
|
||||
source_filter: str = "all", limit: int = 10) -> List[SkillMeta]:
|
||||
def unified_search(
|
||||
query: str, sources: list[SkillSource], source_filter: str = "all", limit: int = 10
|
||||
) -> list[SkillMeta]:
|
||||
"""Search all sources and merge results."""
|
||||
all_results: List[SkillMeta] = []
|
||||
all_results: list[SkillMeta] = []
|
||||
|
||||
for src in sources:
|
||||
if source_filter != "all" and src.source_id() != source_filter:
|
||||
|
|
@ -1410,11 +1440,9 @@ def unified_search(query: str, sources: List[SkillSource],
|
|||
|
||||
# Deduplicate by name, preferring higher trust levels
|
||||
_TRUST_RANK = {"builtin": 2, "trusted": 1, "community": 0}
|
||||
seen: Dict[str, SkillMeta] = {}
|
||||
seen: dict[str, SkillMeta] = {}
|
||||
for r in all_results:
|
||||
if r.name not in seen:
|
||||
seen[r.name] = r
|
||||
elif _TRUST_RANK.get(r.trust_level, 0) > _TRUST_RANK.get(seen[r.name].trust_level, 0):
|
||||
if r.name not in seen or _TRUST_RANK.get(r.trust_level, 0) > _TRUST_RANK.get(seen[r.name].trust_level, 0):
|
||||
seen[r.name] = r
|
||||
deduped = list(seen.values())
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ import logging
|
|||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -41,7 +40,7 @@ def _get_bundled_dir() -> Path:
|
|||
return Path(__file__).parent.parent / "skills"
|
||||
|
||||
|
||||
def _read_manifest() -> Dict[str, str]:
|
||||
def _read_manifest() -> dict[str, str]:
|
||||
"""
|
||||
Read the manifest as a dict of {skill_name: origin_hash}.
|
||||
|
||||
|
|
@ -64,11 +63,11 @@ def _read_manifest() -> Dict[str, str]:
|
|||
# v1 format: plain name — empty hash triggers migration
|
||||
result[line] = ""
|
||||
return result
|
||||
except (OSError, IOError):
|
||||
except OSError:
|
||||
return {}
|
||||
|
||||
|
||||
def _write_manifest(entries: Dict[str, str]):
|
||||
def _write_manifest(entries: dict[str, str]):
|
||||
"""Write the manifest file atomically in v2 format (name:hash).
|
||||
|
||||
Uses a temp file + os.replace() to avoid corruption if the process
|
||||
|
|
@ -101,7 +100,7 @@ def _write_manifest(entries: Dict[str, str]):
|
|||
logger.debug("Failed to write skills manifest %s: %s", MANIFEST_FILE, e, exc_info=True)
|
||||
|
||||
|
||||
def _discover_bundled_skills(bundled_dir: Path) -> List[Tuple[str, Path]]:
|
||||
def _discover_bundled_skills(bundled_dir: Path) -> list[tuple[str, Path]]:
|
||||
"""
|
||||
Find all SKILL.md files in the bundled directory.
|
||||
Returns list of (skill_name, skill_directory_path) tuples.
|
||||
|
|
@ -139,7 +138,7 @@ def _dir_hash(directory: Path) -> str:
|
|||
rel = fpath.relative_to(directory)
|
||||
hasher.update(str(rel).encode("utf-8"))
|
||||
hasher.update(fpath.read_bytes())
|
||||
except (OSError, IOError):
|
||||
except OSError:
|
||||
pass
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
|
@ -155,8 +154,12 @@ def sync_skills(quiet: bool = False) -> dict:
|
|||
bundled_dir = _get_bundled_dir()
|
||||
if not bundled_dir.exists():
|
||||
return {
|
||||
"copied": [], "updated": [], "skipped": 0,
|
||||
"user_modified": [], "cleaned": [], "total_bundled": 0,
|
||||
"copied": [],
|
||||
"updated": [],
|
||||
"skipped": 0,
|
||||
"user_modified": [],
|
||||
"cleaned": [],
|
||||
"total_bundled": 0,
|
||||
}
|
||||
|
||||
SKILLS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -187,7 +190,7 @@ def sync_skills(quiet: bool = False) -> dict:
|
|||
manifest[skill_name] = bundled_hash
|
||||
if not quiet:
|
||||
print(f" + {skill_name}")
|
||||
except (OSError, IOError) as e:
|
||||
except OSError as e:
|
||||
if not quiet:
|
||||
print(f" ! Failed to copy {skill_name}: {e}")
|
||||
# Do NOT add to manifest — next sync should retry
|
||||
|
|
@ -229,12 +232,12 @@ def sync_skills(quiet: bool = False) -> dict:
|
|||
print(f" ↑ {skill_name} (updated)")
|
||||
# Remove backup after successful copy
|
||||
shutil.rmtree(backup, ignore_errors=True)
|
||||
except (OSError, IOError):
|
||||
except OSError:
|
||||
# Restore from backup
|
||||
if backup.exists() and not dest.exists():
|
||||
shutil.move(str(backup), str(dest))
|
||||
raise
|
||||
except (OSError, IOError) as e:
|
||||
except OSError as e:
|
||||
if not quiet:
|
||||
print(f" ! Failed to update {skill_name}: {e}")
|
||||
else:
|
||||
|
|
@ -257,7 +260,7 @@ def sync_skills(quiet: bool = False) -> dict:
|
|||
try:
|
||||
dest_desc.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(desc_md, dest_desc)
|
||||
except (OSError, IOError) as e:
|
||||
except OSError as e:
|
||||
logger.debug("Could not copy %s: %s", desc_md, e)
|
||||
|
||||
_write_manifest(manifest)
|
||||
|
|
|
|||
|
|
@ -40,9 +40,9 @@ SKILL.md Format (YAML Frontmatter, agentskills.io compatible):
|
|||
tags: [fine-tuning, llm]
|
||||
related_skills: [peft, lora]
|
||||
---
|
||||
|
||||
|
||||
# Skill Title
|
||||
|
||||
|
||||
Full instructions and content here...
|
||||
|
||||
Available tools:
|
||||
|
|
@ -51,13 +51,13 @@ Available tools:
|
|||
|
||||
Usage:
|
||||
from tools.skills_tool import skills_list, skill_view, check_skills_requirements
|
||||
|
||||
|
||||
# List all skills (returns metadata only - token efficient)
|
||||
result = skills_list()
|
||||
|
||||
|
||||
# View a skill's main content (loads full instructions)
|
||||
content = skill_view("axolotl")
|
||||
|
||||
|
||||
# View a reference file within a skill (loads linked file)
|
||||
content = skill_view("axolotl", "references/dataset-formats.md")
|
||||
"""
|
||||
|
|
@ -67,11 +67,10 @@ import os
|
|||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
# All skills live in ~/.hermes/skills/ (seeded from bundled skills/ on install).
|
||||
# This is the single source of truth -- agent edits, hub installs, and bundled
|
||||
# skills all coexist here without polluting the git repo.
|
||||
|
|
@ -91,7 +90,7 @@ _PLATFORM_MAP = {
|
|||
}
|
||||
|
||||
|
||||
def skill_matches_platform(frontmatter: Dict[str, Any]) -> bool:
|
||||
def skill_matches_platform(frontmatter: dict[str, Any]) -> bool:
|
||||
"""Check if a skill is compatible with the current OS platform.
|
||||
|
||||
Skills declare platform requirements via a top-level ``platforms`` list
|
||||
|
|
@ -123,28 +122,28 @@ def check_skills_requirements() -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def _parse_frontmatter(content: str) -> Tuple[Dict[str, Any], str]:
|
||||
def _parse_frontmatter(content: str) -> tuple[dict[str, Any], str]:
|
||||
"""
|
||||
Parse YAML frontmatter from markdown content.
|
||||
|
||||
|
||||
Uses yaml.safe_load for full YAML support (nested metadata, lists, etc.)
|
||||
with a fallback to simple key:value splitting for robustness.
|
||||
|
||||
|
||||
Args:
|
||||
content: Full markdown file content
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (frontmatter dict, remaining content)
|
||||
"""
|
||||
frontmatter = {}
|
||||
body = content
|
||||
|
||||
|
||||
if content.startswith("---"):
|
||||
end_match = re.search(r'\n---\s*\n', content[3:])
|
||||
end_match = re.search(r"\n---\s*\n", content[3:])
|
||||
if end_match:
|
||||
yaml_content = content[3:end_match.start() + 3]
|
||||
body = content[end_match.end() + 3:]
|
||||
|
||||
yaml_content = content[3 : end_match.start() + 3]
|
||||
body = content[end_match.end() + 3 :]
|
||||
|
||||
try:
|
||||
parsed = yaml.safe_load(yaml_content)
|
||||
if isinstance(parsed, dict):
|
||||
|
|
@ -152,18 +151,18 @@ def _parse_frontmatter(content: str) -> Tuple[Dict[str, Any], str]:
|
|||
# yaml.safe_load returns None for empty frontmatter
|
||||
except yaml.YAMLError:
|
||||
# Fallback: simple key:value parsing for malformed YAML
|
||||
for line in yaml_content.strip().split('\n'):
|
||||
if ':' in line:
|
||||
key, value = line.split(':', 1)
|
||||
for line in yaml_content.strip().split("\n"):
|
||||
if ":" in line:
|
||||
key, value = line.split(":", 1)
|
||||
frontmatter[key.strip()] = value.strip()
|
||||
|
||||
|
||||
return frontmatter, body
|
||||
|
||||
|
||||
def _get_category_from_path(skill_path: Path) -> Optional[str]:
|
||||
def _get_category_from_path(skill_path: Path) -> str | None:
|
||||
"""
|
||||
Extract category from skill path based on directory structure.
|
||||
|
||||
|
||||
For paths like: ~/.hermes/skills/mlops/axolotl/SKILL.md -> "mlops"
|
||||
"""
|
||||
try:
|
||||
|
|
@ -179,134 +178,136 @@ def _get_category_from_path(skill_path: Path) -> Optional[str]:
|
|||
def _estimate_tokens(content: str) -> int:
|
||||
"""
|
||||
Rough token estimate (4 chars per token average).
|
||||
|
||||
|
||||
Args:
|
||||
content: Text content
|
||||
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
return len(content) // 4
|
||||
|
||||
|
||||
def _parse_tags(tags_value) -> List[str]:
|
||||
def _parse_tags(tags_value) -> list[str]:
|
||||
"""
|
||||
Parse tags from frontmatter value.
|
||||
|
||||
|
||||
Handles:
|
||||
- Already-parsed list (from yaml.safe_load): [tag1, tag2]
|
||||
- String with brackets: "[tag1, tag2]"
|
||||
- Comma-separated string: "tag1, tag2"
|
||||
|
||||
|
||||
Args:
|
||||
tags_value: Raw tags value — may be a list or string
|
||||
|
||||
|
||||
Returns:
|
||||
List of tag strings
|
||||
"""
|
||||
if not tags_value:
|
||||
return []
|
||||
|
||||
|
||||
# yaml.safe_load already returns a list for [tag1, tag2]
|
||||
if isinstance(tags_value, list):
|
||||
return [str(t).strip() for t in tags_value if t]
|
||||
|
||||
|
||||
# String fallback — handle bracket-wrapped or comma-separated
|
||||
tags_value = str(tags_value).strip()
|
||||
if tags_value.startswith('[') and tags_value.endswith(']'):
|
||||
if tags_value.startswith("[") and tags_value.endswith("]"):
|
||||
tags_value = tags_value[1:-1]
|
||||
|
||||
return [t.strip().strip('"\'') for t in tags_value.split(',') if t.strip()]
|
||||
|
||||
return [t.strip().strip("\"'") for t in tags_value.split(",") if t.strip()]
|
||||
|
||||
|
||||
def _find_all_skills() -> List[Dict[str, Any]]:
|
||||
def _find_all_skills() -> list[dict[str, Any]]:
|
||||
"""
|
||||
Recursively find all skills in ~/.hermes/skills/.
|
||||
|
||||
|
||||
Returns metadata for progressive disclosure (tier 1):
|
||||
- name, description, category
|
||||
|
||||
|
||||
Returns:
|
||||
List of skill metadata dicts
|
||||
"""
|
||||
skills = []
|
||||
|
||||
|
||||
if not SKILLS_DIR.exists():
|
||||
return skills
|
||||
|
||||
|
||||
for skill_md in SKILLS_DIR.rglob("SKILL.md"):
|
||||
if any(part in ('.git', '.github', '.hub') for part in skill_md.parts):
|
||||
if any(part in (".git", ".github", ".hub") for part in skill_md.parts):
|
||||
continue
|
||||
|
||||
|
||||
skill_dir = skill_md.parent
|
||||
|
||||
|
||||
try:
|
||||
content = skill_md.read_text(encoding='utf-8')
|
||||
content = skill_md.read_text(encoding="utf-8")
|
||||
frontmatter, body = _parse_frontmatter(content)
|
||||
|
||||
# Skip skills incompatible with the current OS platform
|
||||
if not skill_matches_platform(frontmatter):
|
||||
continue
|
||||
|
||||
name = frontmatter.get('name', skill_dir.name)[:MAX_NAME_LENGTH]
|
||||
|
||||
description = frontmatter.get('description', '')
|
||||
|
||||
name = frontmatter.get("name", skill_dir.name)[:MAX_NAME_LENGTH]
|
||||
|
||||
description = frontmatter.get("description", "")
|
||||
if not description:
|
||||
for line in body.strip().split('\n'):
|
||||
for line in body.strip().split("\n"):
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'):
|
||||
if line and not line.startswith("#"):
|
||||
description = line
|
||||
break
|
||||
|
||||
|
||||
if len(description) > MAX_DESCRIPTION_LENGTH:
|
||||
description = description[:MAX_DESCRIPTION_LENGTH - 3] + "..."
|
||||
|
||||
description = description[: MAX_DESCRIPTION_LENGTH - 3] + "..."
|
||||
|
||||
category = _get_category_from_path(skill_md)
|
||||
|
||||
skills.append({
|
||||
"name": name,
|
||||
"description": description,
|
||||
"category": category,
|
||||
})
|
||||
|
||||
|
||||
skills.append(
|
||||
{
|
||||
"name": name,
|
||||
"description": description,
|
||||
"category": category,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
return skills
|
||||
|
||||
|
||||
def _load_category_description(category_dir: Path) -> Optional[str]:
|
||||
def _load_category_description(category_dir: Path) -> str | None:
|
||||
"""
|
||||
Load category description from DESCRIPTION.md if it exists.
|
||||
|
||||
|
||||
Args:
|
||||
category_dir: Path to the category directory
|
||||
|
||||
|
||||
Returns:
|
||||
Description string or None if not found
|
||||
"""
|
||||
desc_file = category_dir / "DESCRIPTION.md"
|
||||
if not desc_file.exists():
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
content = desc_file.read_text(encoding='utf-8')
|
||||
content = desc_file.read_text(encoding="utf-8")
|
||||
# Parse frontmatter if present
|
||||
frontmatter, body = _parse_frontmatter(content)
|
||||
|
||||
|
||||
# Prefer frontmatter description, fall back to first non-header line
|
||||
description = frontmatter.get('description', '')
|
||||
description = frontmatter.get("description", "")
|
||||
if not description:
|
||||
for line in body.strip().split('\n'):
|
||||
for line in body.strip().split("\n"):
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'):
|
||||
if line and not line.startswith("#"):
|
||||
description = line
|
||||
break
|
||||
|
||||
|
||||
# Truncate to reasonable length
|
||||
if len(description) > MAX_DESCRIPTION_LENGTH:
|
||||
description = description[:MAX_DESCRIPTION_LENGTH - 3] + "..."
|
||||
|
||||
description = description[: MAX_DESCRIPTION_LENGTH - 3] + "..."
|
||||
|
||||
return description if description else None
|
||||
except Exception:
|
||||
return None
|
||||
|
|
@ -315,26 +316,24 @@ def _load_category_description(category_dir: Path) -> Optional[str]:
|
|||
def skills_categories(verbose: bool = False, task_id: str = None) -> str:
|
||||
"""
|
||||
List available skill categories with descriptions (progressive disclosure tier 0).
|
||||
|
||||
|
||||
Returns category names and descriptions for efficient discovery before drilling down.
|
||||
Categories can have a DESCRIPTION.md file with a description frontmatter field
|
||||
or first paragraph to explain what skills are in that category.
|
||||
|
||||
|
||||
Args:
|
||||
verbose: If True, include skill counts per category (default: False, but currently always included)
|
||||
task_id: Optional task identifier (unused, for API consistency)
|
||||
|
||||
|
||||
Returns:
|
||||
JSON string with list of categories and their descriptions
|
||||
"""
|
||||
try:
|
||||
if not SKILLS_DIR.exists():
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"categories": [],
|
||||
"message": "No skills directory found."
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": True, "categories": [], "message": "No skills directory found."}, ensure_ascii=False
|
||||
)
|
||||
|
||||
category_dirs = {}
|
||||
for skill_md in SKILLS_DIR.rglob("SKILL.md"):
|
||||
category = _get_category_from_path(skill_md)
|
||||
|
|
@ -342,121 +341,125 @@ def skills_categories(verbose: bool = False, task_id: str = None) -> str:
|
|||
category_dir = SKILLS_DIR / category
|
||||
if category not in category_dirs:
|
||||
category_dirs[category] = category_dir
|
||||
|
||||
|
||||
categories = []
|
||||
for name in sorted(category_dirs.keys()):
|
||||
category_dir = category_dirs[name]
|
||||
description = _load_category_description(category_dir)
|
||||
skill_count = sum(1 for _ in category_dir.rglob("SKILL.md"))
|
||||
|
||||
|
||||
cat_entry = {"name": name, "skill_count": skill_count}
|
||||
if description:
|
||||
cat_entry["description"] = description
|
||||
categories.append(cat_entry)
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"categories": categories,
|
||||
"hint": "If a category is relevant to your task, use skills_list with that category to see available skills"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"categories": categories,
|
||||
"hint": "If a category is relevant to your task, use skills_list with that category to see available skills",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps({"success": False, "error": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
def skills_list(category: str = None, task_id: str = None) -> str:
|
||||
"""
|
||||
List all available skills (progressive disclosure tier 1 - minimal metadata).
|
||||
|
||||
Returns only name + description to minimize token usage. Use skill_view() to
|
||||
|
||||
Returns only name + description to minimize token usage. Use skill_view() to
|
||||
load full content, tags, related files, etc.
|
||||
|
||||
|
||||
Args:
|
||||
category: Optional category filter (e.g., "mlops")
|
||||
task_id: Optional task identifier (unused, for API consistency)
|
||||
|
||||
|
||||
Returns:
|
||||
JSON string with minimal skill info: name, description, category
|
||||
"""
|
||||
try:
|
||||
if not SKILLS_DIR.exists():
|
||||
SKILLS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"skills": [],
|
||||
"categories": [],
|
||||
"message": "No skills found. Skills directory created at ~/.hermes/skills/"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"skills": [],
|
||||
"categories": [],
|
||||
"message": "No skills found. Skills directory created at ~/.hermes/skills/",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# Find all skills
|
||||
all_skills = _find_all_skills()
|
||||
|
||||
|
||||
if not all_skills:
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"skills": [],
|
||||
"categories": [],
|
||||
"message": "No skills found in skills/ directory."
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": True, "skills": [], "categories": [], "message": "No skills found in skills/ directory."},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# Filter by category if specified
|
||||
if category:
|
||||
all_skills = [s for s in all_skills if s.get("category") == category]
|
||||
|
||||
|
||||
# Sort by category then name
|
||||
all_skills.sort(key=lambda s: (s.get("category") or "", s["name"]))
|
||||
|
||||
|
||||
# Extract unique categories
|
||||
categories = sorted(set(s.get("category") for s in all_skills if s.get("category")))
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"skills": all_skills,
|
||||
"categories": categories,
|
||||
"count": len(all_skills),
|
||||
"hint": "Use skill_view(name) to see full content, tags, and linked files"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"skills": all_skills,
|
||||
"categories": categories,
|
||||
"count": len(all_skills),
|
||||
"hint": "Use skill_view(name) to see full content, tags, and linked files",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps({"success": False, "error": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
|
||||
"""
|
||||
View the content of a skill or a specific file within a skill directory.
|
||||
|
||||
|
||||
Args:
|
||||
name: Name or path of the skill (e.g., "axolotl" or "03-fine-tuning/axolotl")
|
||||
file_path: Optional path to a specific file within the skill (e.g., "references/api.md")
|
||||
task_id: Optional task identifier (unused, for API consistency)
|
||||
|
||||
|
||||
Returns:
|
||||
JSON string with skill content or error message
|
||||
"""
|
||||
try:
|
||||
if not SKILLS_DIR.exists():
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "Skills directory does not exist yet. It will be created on first install."
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Skills directory does not exist yet. It will be created on first install.",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
skill_dir = None
|
||||
skill_md = None
|
||||
|
||||
|
||||
# Try direct path first (e.g., "mlops/axolotl")
|
||||
direct_path = SKILLS_DIR / name
|
||||
if direct_path.is_dir() and (direct_path / "SKILL.md").exists():
|
||||
skill_dir = direct_path
|
||||
skill_md = direct_path / "SKILL.md"
|
||||
elif direct_path.with_suffix('.md').exists():
|
||||
skill_md = direct_path.with_suffix('.md')
|
||||
|
||||
elif direct_path.with_suffix(".md").exists():
|
||||
skill_md = direct_path.with_suffix(".md")
|
||||
|
||||
# Search by directory name
|
||||
if not skill_md:
|
||||
for found_skill_md in SKILLS_DIR.rglob("SKILL.md"):
|
||||
|
|
@ -464,64 +467,70 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
|
|||
skill_dir = found_skill_md.parent
|
||||
skill_md = found_skill_md
|
||||
break
|
||||
|
||||
|
||||
# Legacy: flat .md files
|
||||
if not skill_md:
|
||||
for found_md in SKILLS_DIR.rglob(f"{name}.md"):
|
||||
if found_md.name != "SKILL.md":
|
||||
skill_md = found_md
|
||||
break
|
||||
|
||||
|
||||
if not skill_md or not skill_md.exists():
|
||||
# List available skills in error message
|
||||
all_skills = _find_all_skills()
|
||||
available = [s["name"] for s in all_skills[:20]] # Limit to 20
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": f"Skill '{name}' not found.",
|
||||
"available_skills": available,
|
||||
"hint": "Use skills_list to see all available skills"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"Skill '{name}' not found.",
|
||||
"available_skills": available,
|
||||
"hint": "Use skills_list to see all available skills",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# If a specific file path is requested, read that instead
|
||||
if file_path and skill_dir:
|
||||
# Security: Prevent path traversal attacks
|
||||
normalized_path = Path(file_path)
|
||||
if ".." in normalized_path.parts:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "Path traversal ('..') is not allowed.",
|
||||
"hint": "Use a relative path within the skill directory"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Path traversal ('..') is not allowed.",
|
||||
"hint": "Use a relative path within the skill directory",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
target_file = skill_dir / file_path
|
||||
|
||||
|
||||
# Security: Verify resolved path is still within skill directory
|
||||
try:
|
||||
resolved = target_file.resolve()
|
||||
skill_dir_resolved = skill_dir.resolve()
|
||||
if not resolved.is_relative_to(skill_dir_resolved):
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "Path escapes skill directory boundary.",
|
||||
"hint": "Use a relative path within the skill directory"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Path escapes skill directory boundary.",
|
||||
"hint": "Use a relative path within the skill directory",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
except (OSError, ValueError):
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": f"Invalid file path: '{file_path}'",
|
||||
"hint": "Use a valid relative path within the skill directory"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"Invalid file path: '{file_path}'",
|
||||
"hint": "Use a valid relative path within the skill directory",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
if not target_file.exists():
|
||||
# List available files in the skill directory, organized by type
|
||||
available_files = {
|
||||
"references": [],
|
||||
"templates": [],
|
||||
"assets": [],
|
||||
"scripts": [],
|
||||
"other": []
|
||||
}
|
||||
|
||||
available_files = {"references": [], "templates": [], "assets": [], "scripts": [], "other": []}
|
||||
|
||||
# Scan for all readable files
|
||||
for f in skill_dir.rglob("*"):
|
||||
if f.is_file() and f.name != "SKILL.md":
|
||||
|
|
@ -534,82 +543,85 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
|
|||
available_files["assets"].append(rel)
|
||||
elif rel.startswith("scripts/"):
|
||||
available_files["scripts"].append(rel)
|
||||
elif f.suffix in ['.md', '.py', '.yaml', '.yml', '.json', '.tex', '.sh']:
|
||||
elif f.suffix in [".md", ".py", ".yaml", ".yml", ".json", ".tex", ".sh"]:
|
||||
available_files["other"].append(rel)
|
||||
|
||||
|
||||
# Remove empty categories
|
||||
available_files = {k: v for k, v in available_files.items() if v}
|
||||
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": f"File '{file_path}' not found in skill '{name}'.",
|
||||
"available_files": available_files,
|
||||
"hint": "Use one of the available file paths listed above"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"File '{file_path}' not found in skill '{name}'.",
|
||||
"available_files": available_files,
|
||||
"hint": "Use one of the available file paths listed above",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# Read the file content
|
||||
try:
|
||||
content = target_file.read_text(encoding='utf-8')
|
||||
content = target_file.read_text(encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
# Binary file - return info about it instead
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"name": name,
|
||||
"file": file_path,
|
||||
"content": f"[Binary file: {target_file.name}, size: {target_file.stat().st_size} bytes]",
|
||||
"is_binary": True
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"name": name,
|
||||
"file": file_path,
|
||||
"content": content,
|
||||
"file_type": target_file.suffix
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"name": name,
|
||||
"file": file_path,
|
||||
"content": f"[Binary file: {target_file.name}, size: {target_file.stat().st_size} bytes]",
|
||||
"is_binary": True,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
return json.dumps(
|
||||
{"success": True, "name": name, "file": file_path, "content": content, "file_type": target_file.suffix},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# Read the main skill content
|
||||
content = skill_md.read_text(encoding='utf-8')
|
||||
content = skill_md.read_text(encoding="utf-8")
|
||||
frontmatter, body = _parse_frontmatter(content)
|
||||
|
||||
|
||||
# Get reference, template, asset, and script files if this is a directory-based skill
|
||||
reference_files = []
|
||||
template_files = []
|
||||
asset_files = []
|
||||
script_files = []
|
||||
|
||||
|
||||
if skill_dir:
|
||||
references_dir = skill_dir / "references"
|
||||
if references_dir.exists():
|
||||
reference_files = [str(f.relative_to(skill_dir)) for f in references_dir.glob("*.md")]
|
||||
|
||||
|
||||
templates_dir = skill_dir / "templates"
|
||||
if templates_dir.exists():
|
||||
for ext in ['*.md', '*.py', '*.yaml', '*.yml', '*.json', '*.tex', '*.sh']:
|
||||
for ext in ["*.md", "*.py", "*.yaml", "*.yml", "*.json", "*.tex", "*.sh"]:
|
||||
template_files.extend([str(f.relative_to(skill_dir)) for f in templates_dir.rglob(ext)])
|
||||
|
||||
|
||||
# assets/ — agentskills.io standard directory for supplementary files
|
||||
assets_dir = skill_dir / "assets"
|
||||
if assets_dir.exists():
|
||||
for f in assets_dir.rglob("*"):
|
||||
if f.is_file():
|
||||
asset_files.append(str(f.relative_to(skill_dir)))
|
||||
|
||||
|
||||
scripts_dir = skill_dir / "scripts"
|
||||
if scripts_dir.exists():
|
||||
for ext in ['*.py', '*.sh', '*.bash', '*.js', '*.ts', '*.rb']:
|
||||
for ext in ["*.py", "*.sh", "*.bash", "*.js", "*.ts", "*.rb"]:
|
||||
script_files.extend([str(f.relative_to(skill_dir)) for f in scripts_dir.glob(ext)])
|
||||
|
||||
|
||||
# Read tags/related_skills with backward compat:
|
||||
# Check metadata.hermes.* first (agentskills.io convention), fall back to top-level
|
||||
hermes_meta = {}
|
||||
metadata = frontmatter.get('metadata')
|
||||
metadata = frontmatter.get("metadata")
|
||||
if isinstance(metadata, dict):
|
||||
hermes_meta = metadata.get('hermes', {}) or {}
|
||||
|
||||
tags = _parse_tags(hermes_meta.get('tags') or frontmatter.get('tags', ''))
|
||||
related_skills = _parse_tags(hermes_meta.get('related_skills') or frontmatter.get('related_skills', ''))
|
||||
|
||||
hermes_meta = metadata.get("hermes", {}) or {}
|
||||
|
||||
tags = _parse_tags(hermes_meta.get("tags") or frontmatter.get("tags", ""))
|
||||
related_skills = _parse_tags(hermes_meta.get("related_skills") or frontmatter.get("related_skills", ""))
|
||||
|
||||
# Build linked files structure for clear discovery
|
||||
linked_files = {}
|
||||
if reference_files:
|
||||
|
|
@ -620,34 +632,33 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
|
|||
linked_files["assets"] = asset_files
|
||||
if script_files:
|
||||
linked_files["scripts"] = script_files
|
||||
|
||||
|
||||
rel_path = str(skill_md.relative_to(SKILLS_DIR))
|
||||
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"name": frontmatter.get('name', skill_md.stem if not skill_dir else skill_dir.name),
|
||||
"description": frontmatter.get('description', ''),
|
||||
"name": frontmatter.get("name", skill_md.stem if not skill_dir else skill_dir.name),
|
||||
"description": frontmatter.get("description", ""),
|
||||
"tags": tags,
|
||||
"related_skills": related_skills,
|
||||
"content": content,
|
||||
"path": rel_path,
|
||||
"linked_files": linked_files if linked_files else None,
|
||||
"usage_hint": "To view linked files, call skill_view(name, file_path) where file_path is e.g. 'references/api.md' or 'assets/config.yaml'" if linked_files else None
|
||||
"usage_hint": "To view linked files, call skill_view(name, file_path) where file_path is e.g. 'references/api.md' or 'assets/config.yaml'"
|
||||
if linked_files
|
||||
else None,
|
||||
}
|
||||
|
||||
|
||||
# Surface agentskills.io optional fields when present
|
||||
if frontmatter.get('compatibility'):
|
||||
result["compatibility"] = frontmatter['compatibility']
|
||||
if frontmatter.get("compatibility"):
|
||||
result["compatibility"] = frontmatter["compatibility"]
|
||||
if isinstance(metadata, dict):
|
||||
result["metadata"] = metadata
|
||||
|
||||
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps({"success": False, "error": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
# Tool description for model_tools.py
|
||||
|
|
@ -669,7 +680,7 @@ if __name__ == "__main__":
|
|||
"""Test the skills tool"""
|
||||
print("🎯 Skills Tool Test")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# Test listing skills
|
||||
print("\n📋 Listing all skills:")
|
||||
result = json.loads(skills_list())
|
||||
|
|
@ -678,12 +689,12 @@ if __name__ == "__main__":
|
|||
print(f"Categories: {result.get('categories', [])}")
|
||||
print("\nFirst 10 skills:")
|
||||
for skill in result["skills"][:10]:
|
||||
cat = f"[{skill['category']}] " if skill.get('category') else ""
|
||||
refs = f" (+{len(skill['reference_files'])} refs)" if skill.get('reference_files') else ""
|
||||
cat = f"[{skill['category']}] " if skill.get("category") else ""
|
||||
refs = f" (+{len(skill['reference_files'])} refs)" if skill.get("reference_files") else ""
|
||||
print(f" • {cat}{skill['name']}: {skill['description'][:60]}...{refs}")
|
||||
else:
|
||||
print(f"Error: {result['error']}")
|
||||
|
||||
|
||||
# Test viewing a skill
|
||||
print("\n📖 Viewing skill 'axolotl':")
|
||||
result = json.loads(skill_view("axolotl"))
|
||||
|
|
@ -691,11 +702,11 @@ if __name__ == "__main__":
|
|||
print(f"Name: {result['name']}")
|
||||
print(f"Description: {result.get('description', 'N/A')[:100]}...")
|
||||
print(f"Content length: {len(result['content'])} chars")
|
||||
if result.get('reference_files'):
|
||||
if result.get("reference_files"):
|
||||
print(f"Reference files: {result['reference_files']}")
|
||||
else:
|
||||
print(f"Error: {result['error']}")
|
||||
|
||||
|
||||
# Test viewing a reference file
|
||||
print("\n📄 Viewing reference file 'axolotl/references/dataset-formats.md':")
|
||||
result = json.loads(skill_view("axolotl", "references/dataset-formats.md"))
|
||||
|
|
@ -717,14 +728,9 @@ SKILLS_LIST_SCHEMA = {
|
|||
"description": "List available skills (name + description). Use skill_view(name) to load full content.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"category": {
|
||||
"type": "string",
|
||||
"description": "Optional category filter to narrow results"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
"properties": {"category": {"type": "string", "description": "Optional category filter to narrow results"}},
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
SKILL_VIEW_SCHEMA = {
|
||||
|
|
@ -733,17 +739,14 @@ SKILL_VIEW_SCHEMA = {
|
|||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "The skill name (use skills_list to see available skills)"
|
||||
},
|
||||
"name": {"type": "string", "description": "The skill name (use skills_list to see available skills)"},
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "OPTIONAL: Path to a linked file within the skill (e.g., 'references/api.md', 'templates/config.yaml', 'scripts/validate.py'). Omit to get the main SKILL.md content."
|
||||
}
|
||||
"description": "OPTIONAL: Path to a linked file within the skill (e.g., 'references/api.md', 'templates/config.yaml', 'scripts/validate.py'). Omit to get the main SKILL.md content.",
|
||||
},
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
"required": ["name"],
|
||||
},
|
||||
}
|
||||
|
||||
registry.register(
|
||||
|
|
|
|||
|
|
@ -26,20 +26,22 @@ Usage:
|
|||
result = terminal_tool("python server.py", background=True)
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import atexit
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import uuid
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from tools.interrupt import (
|
||||
_interrupt_event, # noqa: F401 — re-exported to environments/local.py
|
||||
is_interrupted, # noqa: F401 — re-exported
|
||||
)
|
||||
from tools.interrupt import set_interrupt as set_interrupt_event # noqa: F401 — re-exported
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -49,7 +51,6 @@ logger = logging.getLogger(__name__)
|
|||
# The terminal tool polls this during command execution so it can kill
|
||||
# long-running subprocesses immediately instead of blocking until timeout.
|
||||
# ---------------------------------------------------------------------------
|
||||
from tools.interrupt import set_interrupt as set_interrupt_event, is_interrupted, _interrupt_event
|
||||
|
||||
|
||||
# Add mini-swe-agent to path if not installed
|
||||
|
|
@ -65,7 +66,6 @@ if mini_swe_path.exists():
|
|||
# Singularity helpers (scratch dir, SIF cache) now live in tools/environments/singularity.py
|
||||
from tools.environments.singularity import _get_scratch_dir
|
||||
|
||||
|
||||
# Disk usage warning threshold (in GB)
|
||||
DISK_USAGE_WARNING_THRESHOLD_GB = float(os.getenv("TERMINAL_DISK_WARNING_GB", "500"))
|
||||
|
||||
|
|
@ -73,28 +73,32 @@ DISK_USAGE_WARNING_THRESHOLD_GB = float(os.getenv("TERMINAL_DISK_WARNING_GB", "5
|
|||
def _check_disk_usage_warning():
|
||||
"""Check if total disk usage exceeds warning threshold."""
|
||||
scratch_dir = _get_scratch_dir()
|
||||
|
||||
|
||||
try:
|
||||
# Get total size of hermes directories
|
||||
total_bytes = 0
|
||||
import glob
|
||||
|
||||
for path in glob.glob(str(scratch_dir / "hermes-*")):
|
||||
for f in Path(path).rglob('*'):
|
||||
for f in Path(path).rglob("*"):
|
||||
if f.is_file():
|
||||
try:
|
||||
total_bytes += f.stat().st_size
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
total_gb = total_bytes / (1024 ** 3)
|
||||
|
||||
|
||||
total_gb = total_bytes / (1024**3)
|
||||
|
||||
if total_gb > DISK_USAGE_WARNING_THRESHOLD_GB:
|
||||
logger.warning("Disk usage (%.1fGB) exceeds threshold (%.0fGB). Consider running cleanup_all_environments().",
|
||||
total_gb, DISK_USAGE_WARNING_THRESHOLD_GB)
|
||||
logger.warning(
|
||||
"Disk usage (%.1fGB) exceeds threshold (%.0fGB). Consider running cleanup_all_environments().",
|
||||
total_gb,
|
||||
DISK_USAGE_WARNING_THRESHOLD_GB,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -121,59 +125,59 @@ def set_approval_callback(cb):
|
|||
global _approval_callback
|
||||
_approval_callback = cb
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Dangerous Command Approval System
|
||||
# =============================================================================
|
||||
|
||||
# Dangerous command detection + approval now consolidated in tools/approval.py
|
||||
from tools.approval import (
|
||||
detect_dangerous_command as _detect_dangerous_command,
|
||||
check_dangerous_command as _check_dangerous_command_impl,
|
||||
load_permanent_allowlist as _load_permanent_allowlist,
|
||||
DANGEROUS_PATTERNS,
|
||||
)
|
||||
|
||||
|
||||
def _check_dangerous_command(command: str, env_type: str) -> dict:
|
||||
"""Delegate to the consolidated approval module, passing the CLI callback."""
|
||||
return _check_dangerous_command_impl(command, env_type,
|
||||
approval_callback=_approval_callback)
|
||||
return _check_dangerous_command_impl(command, env_type, approval_callback=_approval_callback)
|
||||
|
||||
|
||||
def _handle_sudo_failure(output: str, env_type: str) -> str:
|
||||
"""
|
||||
Check for sudo failure and add helpful message for messaging contexts.
|
||||
|
||||
|
||||
Returns enhanced output if sudo failed in messaging context, else original.
|
||||
"""
|
||||
is_gateway = os.getenv("HERMES_GATEWAY_SESSION")
|
||||
|
||||
|
||||
if not is_gateway:
|
||||
return output
|
||||
|
||||
|
||||
# Check for sudo failure indicators
|
||||
sudo_failures = [
|
||||
"sudo: a password is required",
|
||||
"sudo: no tty present",
|
||||
"sudo: a terminal is required",
|
||||
]
|
||||
|
||||
|
||||
for failure in sudo_failures:
|
||||
if failure in output:
|
||||
return output + "\n\n💡 Tip: To enable sudo over messaging, add SUDO_PASSWORD to ~/.hermes/.env on the agent machine."
|
||||
|
||||
return (
|
||||
output
|
||||
+ "\n\n💡 Tip: To enable sudo over messaging, add SUDO_PASSWORD to ~/.hermes/.env on the agent machine."
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
|
||||
"""
|
||||
Prompt user for sudo password with timeout.
|
||||
|
||||
|
||||
Returns the password if entered, or empty string if:
|
||||
- User presses Enter without input (skip)
|
||||
- Timeout expires (45s default)
|
||||
- Any error occurs
|
||||
|
||||
|
||||
Only works in interactive mode (HERMES_INTERACTIVE=1).
|
||||
If a _sudo_password_callback is registered (by the CLI), delegates to it
|
||||
so the prompt integrates with prompt_toolkit's UI. Otherwise reads
|
||||
|
|
@ -181,7 +185,7 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
|
|||
"""
|
||||
import sys
|
||||
import time as time_module
|
||||
|
||||
|
||||
# Use the registered callback when available (prompt_toolkit-compatible)
|
||||
if _sudo_password_callback is not None:
|
||||
try:
|
||||
|
|
@ -190,13 +194,14 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
|
|||
return ""
|
||||
|
||||
result = {"password": None, "done": False}
|
||||
|
||||
|
||||
def read_password_thread():
|
||||
"""Read password from /dev/tty with echo disabled."""
|
||||
tty_fd = None
|
||||
old_attrs = None
|
||||
try:
|
||||
import termios
|
||||
|
||||
tty_fd = os.open("/dev/tty", os.O_RDONLY)
|
||||
old_attrs = termios.tcgetattr(tty_fd)
|
||||
new_attrs = termios.tcgetattr(tty_fd)
|
||||
|
|
@ -217,6 +222,7 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
|
|||
if tty_fd is not None and old_attrs is not None:
|
||||
try:
|
||||
import termios as _termios
|
||||
|
||||
_termios.tcsetattr(tty_fd, _termios.TCSAFLUSH, old_attrs)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
@ -226,11 +232,11 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
|
|||
except Exception:
|
||||
pass
|
||||
result["done"] = True
|
||||
|
||||
|
||||
try:
|
||||
os.environ["HERMES_SPINNER_PAUSE"] = "1"
|
||||
time_module.sleep(0.2)
|
||||
|
||||
|
||||
print()
|
||||
print("┌" + "─" * 58 + "┐")
|
||||
print("│ 🔐 SUDO PASSWORD REQUIRED" + " " * 30 + "│")
|
||||
|
|
@ -241,11 +247,11 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
|
|||
print("└" + "─" * 58 + "┘")
|
||||
print()
|
||||
print(" Password (hidden): ", end="", flush=True)
|
||||
|
||||
|
||||
password_thread = threading.Thread(target=read_password_thread, daemon=True)
|
||||
password_thread.start()
|
||||
password_thread.join(timeout=timeout_seconds)
|
||||
|
||||
|
||||
if result["done"]:
|
||||
password = result["password"] or ""
|
||||
print() # newline after hidden input
|
||||
|
|
@ -262,7 +268,7 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
|
|||
print()
|
||||
sys.stdout.flush()
|
||||
return ""
|
||||
|
||||
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print()
|
||||
print(" ⏭ Cancelled - continuing without sudo")
|
||||
|
|
@ -281,29 +287,29 @@ def _prompt_for_sudo_password(timeout_seconds: int = 45) -> str:
|
|||
def _transform_sudo_command(command: str) -> str:
|
||||
"""
|
||||
Transform sudo commands to use -S flag if SUDO_PASSWORD is available.
|
||||
|
||||
|
||||
This is a shared helper used by all execution environments to provide
|
||||
consistent sudo handling across local, SSH, and container environments.
|
||||
|
||||
|
||||
If SUDO_PASSWORD is set (via env, config, or interactive prompt):
|
||||
'sudo apt install curl' -> password piped via sudo -S
|
||||
|
||||
|
||||
If SUDO_PASSWORD is not set and in interactive mode (HERMES_INTERACTIVE=1):
|
||||
Prompts user for password with 45s timeout, caches for session.
|
||||
|
||||
|
||||
If SUDO_PASSWORD is not set and NOT interactive:
|
||||
Command runs as-is (fails gracefully with "sudo: a password is required").
|
||||
"""
|
||||
global _cached_sudo_password
|
||||
import re
|
||||
|
||||
|
||||
# Check if command even contains sudo
|
||||
if not re.search(r'\bsudo\b', command):
|
||||
if not re.search(r"\bsudo\b", command):
|
||||
return command # No sudo in command, return as-is
|
||||
|
||||
|
||||
# Try to get password from: env var -> session cache -> interactive prompt
|
||||
sudo_password = os.getenv("SUDO_PASSWORD", "") or _cached_sudo_password
|
||||
|
||||
|
||||
if not sudo_password:
|
||||
# No password configured - check if we're in interactive mode
|
||||
if os.getenv("HERMES_INTERACTIVE"):
|
||||
|
|
@ -311,30 +317,30 @@ def _transform_sudo_command(command: str) -> str:
|
|||
sudo_password = _prompt_for_sudo_password(timeout_seconds=45)
|
||||
if sudo_password:
|
||||
_cached_sudo_password = sudo_password # Cache for session
|
||||
|
||||
|
||||
if not sudo_password:
|
||||
return command # No password, let it fail gracefully
|
||||
|
||||
|
||||
def replace_sudo(match):
|
||||
# Replace 'sudo' with password-piped version
|
||||
# The -S flag makes sudo read password from stdin
|
||||
# The -p '' suppresses the password prompt
|
||||
# Use shlex.quote() to prevent shell injection via password content
|
||||
import shlex
|
||||
|
||||
return f"echo {shlex.quote(sudo_password)} | sudo -S -p ''"
|
||||
|
||||
|
||||
# Match 'sudo' at word boundaries (not 'visudo' or 'sudoers')
|
||||
# This handles: sudo, sudo -flag, etc.
|
||||
return re.sub(r'\bsudo\b', replace_sudo, command)
|
||||
return re.sub(r"\bsudo\b", replace_sudo, command)
|
||||
|
||||
|
||||
# Environment classes now live in tools/environments/
|
||||
from tools.environments.docker import DockerEnvironment as _DockerEnvironment
|
||||
from tools.environments.local import LocalEnvironment as _LocalEnvironment
|
||||
from tools.environments.modal import ModalEnvironment as _ModalEnvironment
|
||||
from tools.environments.singularity import SingularityEnvironment as _SingularityEnvironment
|
||||
from tools.environments.ssh import SSHEnvironment as _SSHEnvironment
|
||||
from tools.environments.docker import DockerEnvironment as _DockerEnvironment
|
||||
from tools.environments.modal import ModalEnvironment as _ModalEnvironment
|
||||
|
||||
|
||||
# Tool description for LLM
|
||||
TERMINAL_TOOL_DESCRIPTION = """Execute shell commands on a Linux environment. Filesystem persists between calls.
|
||||
|
|
@ -356,10 +362,10 @@ Do NOT use vim/nano/interactive tools without pty=true — they hang without a p
|
|||
"""
|
||||
|
||||
# Global state for environment lifecycle management
|
||||
_active_environments: Dict[str, Any] = {}
|
||||
_last_activity: Dict[str, float] = {}
|
||||
_active_environments: dict[str, Any] = {}
|
||||
_last_activity: dict[str, float] = {}
|
||||
_env_lock = threading.Lock()
|
||||
_creation_locks: Dict[str, threading.Lock] = {} # Per-task locks for sandbox creation
|
||||
_creation_locks: dict[str, threading.Lock] = {} # Per-task locks for sandbox creation
|
||||
_creation_locks_lock = threading.Lock() # Protects _creation_locks dict itself
|
||||
_cleanup_thread = None
|
||||
_cleanup_running = False
|
||||
|
|
@ -372,10 +378,10 @@ _cleanup_running = False
|
|||
#
|
||||
# This is never exposed to the model -- only infrastructure code calls it.
|
||||
# Thread-safe because each task_id is unique per rollout.
|
||||
_task_env_overrides: Dict[str, Dict[str, Any]] = {}
|
||||
_task_env_overrides: dict[str, dict[str, Any]] = {}
|
||||
|
||||
|
||||
def register_task_env_overrides(task_id: str, overrides: Dict[str, Any]):
|
||||
def register_task_env_overrides(task_id: str, overrides: dict[str, Any]):
|
||||
"""
|
||||
Register environment overrides for a specific task/rollout.
|
||||
|
||||
|
|
@ -402,13 +408,14 @@ def clear_task_env_overrides(task_id: str):
|
|||
"""
|
||||
_task_env_overrides.pop(task_id, None)
|
||||
|
||||
|
||||
# Configuration from environment variables
|
||||
def _get_env_config() -> Dict[str, Any]:
|
||||
def _get_env_config() -> dict[str, Any]:
|
||||
"""Get terminal environment configuration from environment variables."""
|
||||
# Default image with Python and Node.js for maximum compatibility
|
||||
default_image = "nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
env_type = os.getenv("TERMINAL_ENV", "local")
|
||||
|
||||
|
||||
# Default cwd: local uses the host's current directory, everything
|
||||
# else starts in the user's home (~ resolves to whatever account
|
||||
# is running inside the container/remote).
|
||||
|
|
@ -416,7 +423,7 @@ def _get_env_config() -> Dict[str, Any]:
|
|||
default_cwd = os.getcwd()
|
||||
else:
|
||||
default_cwd = "~"
|
||||
|
||||
|
||||
# Read TERMINAL_CWD but sanity-check it for container backends.
|
||||
# If the CWD looks like a host-local path that can't exist inside a
|
||||
# container/sandbox, fall back to the backend's own default. This
|
||||
|
|
@ -426,9 +433,12 @@ def _get_env_config() -> Dict[str, Any]:
|
|||
if env_type in ("modal", "docker", "singularity", "daytona") and cwd:
|
||||
host_prefixes = ("/Users/", "C:\\", "C:/")
|
||||
if any(cwd.startswith(p) for p in host_prefixes) and cwd != default_cwd:
|
||||
logger.info("Ignoring TERMINAL_CWD=%r for %s backend "
|
||||
"(host path won't exist in sandbox). Using %r instead.",
|
||||
cwd, env_type, default_cwd)
|
||||
logger.info(
|
||||
"Ignoring TERMINAL_CWD=%r for %s backend (host path won't exist in sandbox). Using %r instead.",
|
||||
cwd,
|
||||
env_type,
|
||||
default_cwd,
|
||||
)
|
||||
cwd = default_cwd
|
||||
|
||||
return {
|
||||
|
|
@ -447,19 +457,25 @@ def _get_env_config() -> Dict[str, Any]:
|
|||
"ssh_key": os.getenv("TERMINAL_SSH_KEY", ""),
|
||||
# Container resource config (applies to docker, singularity, modal, daytona -- ignored for local/ssh)
|
||||
"container_cpu": float(os.getenv("TERMINAL_CONTAINER_CPU", "1")),
|
||||
"container_memory": int(os.getenv("TERMINAL_CONTAINER_MEMORY", "5120")), # MB (default 5GB)
|
||||
"container_disk": int(os.getenv("TERMINAL_CONTAINER_DISK", "51200")), # MB (default 50GB)
|
||||
"container_memory": int(os.getenv("TERMINAL_CONTAINER_MEMORY", "5120")), # MB (default 5GB)
|
||||
"container_disk": int(os.getenv("TERMINAL_CONTAINER_DISK", "51200")), # MB (default 50GB)
|
||||
"container_persistent": os.getenv("TERMINAL_CONTAINER_PERSISTENT", "true").lower() in ("true", "1", "yes"),
|
||||
"docker_volumes": json.loads(os.getenv("TERMINAL_DOCKER_VOLUMES", "[]")),
|
||||
}
|
||||
|
||||
|
||||
def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
|
||||
ssh_config: dict = None, container_config: dict = None,
|
||||
task_id: str = "default"):
|
||||
def _create_environment(
|
||||
env_type: str,
|
||||
image: str,
|
||||
cwd: str,
|
||||
timeout: int,
|
||||
ssh_config: dict = None,
|
||||
container_config: dict = None,
|
||||
task_id: str = "default",
|
||||
):
|
||||
"""
|
||||
Create an execution environment from mini-swe-agent.
|
||||
|
||||
|
||||
Args:
|
||||
env_type: One of "local", "docker", "singularity", "modal", "daytona", "ssh"
|
||||
image: Docker/Singularity/Modal image name (ignored for local/ssh)
|
||||
|
|
@ -468,7 +484,7 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
|
|||
ssh_config: SSH connection config (for env_type="ssh")
|
||||
container_config: Resource config for container backends (cpu, memory, disk, persistent)
|
||||
task_id: Task identifier for environment reuse and snapshot keying
|
||||
|
||||
|
||||
Returns:
|
||||
Environment instance with execute() method
|
||||
"""
|
||||
|
|
@ -481,22 +497,32 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
|
|||
|
||||
if env_type == "local":
|
||||
return _LocalEnvironment(cwd=cwd, timeout=timeout)
|
||||
|
||||
|
||||
elif env_type == "docker":
|
||||
return _DockerEnvironment(
|
||||
image=image, cwd=cwd, timeout=timeout,
|
||||
cpu=cpu, memory=memory, disk=disk,
|
||||
persistent_filesystem=persistent, task_id=task_id,
|
||||
image=image,
|
||||
cwd=cwd,
|
||||
timeout=timeout,
|
||||
cpu=cpu,
|
||||
memory=memory,
|
||||
disk=disk,
|
||||
persistent_filesystem=persistent,
|
||||
task_id=task_id,
|
||||
volumes=volumes,
|
||||
)
|
||||
|
||||
|
||||
elif env_type == "singularity":
|
||||
return _SingularityEnvironment(
|
||||
image=image, cwd=cwd, timeout=timeout,
|
||||
cpu=cpu, memory=memory, disk=disk,
|
||||
persistent_filesystem=persistent, task_id=task_id,
|
||||
image=image,
|
||||
cwd=cwd,
|
||||
timeout=timeout,
|
||||
cpu=cpu,
|
||||
memory=memory,
|
||||
disk=disk,
|
||||
persistent_filesystem=persistent,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
|
||||
elif env_type == "modal":
|
||||
sandbox_kwargs = {}
|
||||
if cpu > 0:
|
||||
|
|
@ -505,20 +531,29 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
|
|||
sandbox_kwargs["memory"] = memory
|
||||
if disk > 0:
|
||||
sandbox_kwargs["ephemeral_disk"] = disk
|
||||
|
||||
|
||||
return _ModalEnvironment(
|
||||
image=image, cwd=cwd, timeout=timeout,
|
||||
image=image,
|
||||
cwd=cwd,
|
||||
timeout=timeout,
|
||||
modal_sandbox_kwargs=sandbox_kwargs,
|
||||
persistent_filesystem=persistent, task_id=task_id,
|
||||
persistent_filesystem=persistent,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
|
||||
elif env_type == "daytona":
|
||||
# Lazy import so daytona SDK is only required when backend is selected.
|
||||
from tools.environments.daytona import DaytonaEnvironment as _DaytonaEnvironment
|
||||
|
||||
return _DaytonaEnvironment(
|
||||
image=image, cwd=cwd, timeout=timeout,
|
||||
cpu=int(cpu), memory=memory, disk=disk,
|
||||
persistent_filesystem=persistent, task_id=task_id,
|
||||
image=image,
|
||||
cwd=cwd,
|
||||
timeout=timeout,
|
||||
cpu=int(cpu),
|
||||
memory=memory,
|
||||
disk=disk,
|
||||
persistent_filesystem=persistent,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
elif env_type == "ssh":
|
||||
|
|
@ -534,7 +569,9 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int,
|
|||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown environment type: {env_type}. Use 'local', 'docker', 'singularity', 'modal', 'daytona', or 'ssh'")
|
||||
raise ValueError(
|
||||
f"Unknown environment type: {env_type}. Use 'local', 'docker', 'singularity', 'modal', 'daytona', or 'ssh'"
|
||||
)
|
||||
|
||||
|
||||
def _cleanup_inactive_envs(lifetime_seconds: int = 300):
|
||||
|
|
@ -547,6 +584,7 @@ def _cleanup_inactive_envs(lifetime_seconds: int = 300):
|
|||
# background processes (their _last_activity gets refreshed to keep them alive).
|
||||
try:
|
||||
from tools.process_registry import process_registry
|
||||
|
||||
for task_id in list(_last_activity.keys()):
|
||||
if process_registry.has_active_processes(task_id):
|
||||
_last_activity[task_id] = current_time # Keep sandbox alive
|
||||
|
|
@ -579,16 +617,17 @@ def _cleanup_inactive_envs(lifetime_seconds: int = 300):
|
|||
# ShellFileOperations from referencing a dead sandbox)
|
||||
try:
|
||||
from tools.file_tools import clear_file_ops_cache
|
||||
|
||||
clear_file_ops_cache(task_id)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
if hasattr(env, 'cleanup'):
|
||||
if hasattr(env, "cleanup"):
|
||||
env.cleanup()
|
||||
elif hasattr(env, 'stop'):
|
||||
elif hasattr(env, "stop"):
|
||||
env.stop()
|
||||
elif hasattr(env, 'terminate'):
|
||||
elif hasattr(env, "terminate"):
|
||||
env.terminate()
|
||||
|
||||
logger.info("Cleaned up inactive environment for task: %s", task_id)
|
||||
|
|
@ -640,27 +679,28 @@ def _stop_cleanup_thread():
|
|||
pass
|
||||
|
||||
|
||||
def get_active_environments_info() -> Dict[str, Any]:
|
||||
def get_active_environments_info() -> dict[str, Any]:
|
||||
"""Get information about currently active environments."""
|
||||
info = {
|
||||
"count": len(_active_environments),
|
||||
"task_ids": list(_active_environments.keys()),
|
||||
"workdirs": {},
|
||||
}
|
||||
|
||||
|
||||
# Calculate total disk usage (per-task to avoid double-counting)
|
||||
total_size = 0
|
||||
for task_id in _active_environments.keys():
|
||||
for task_id in _active_environments:
|
||||
scratch_dir = _get_scratch_dir()
|
||||
pattern = f"hermes-*{task_id[:8]}*"
|
||||
import glob
|
||||
|
||||
for path in glob.glob(str(scratch_dir / pattern)):
|
||||
try:
|
||||
size = sum(f.stat().st_size for f in Path(path).rglob('*') if f.is_file())
|
||||
size = sum(f.stat().st_size for f in Path(path).rglob("*") if f.is_file())
|
||||
total_size += size
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
info["total_disk_usage_mb"] = round(total_size / (1024 * 1024), 2)
|
||||
return info
|
||||
|
||||
|
|
@ -668,27 +708,28 @@ def get_active_environments_info() -> Dict[str, Any]:
|
|||
def cleanup_all_environments():
|
||||
"""Clean up ALL active environments. Use with caution."""
|
||||
global _active_environments, _last_activity
|
||||
|
||||
|
||||
task_ids = list(_active_environments.keys())
|
||||
cleaned = 0
|
||||
|
||||
|
||||
for task_id in task_ids:
|
||||
try:
|
||||
cleanup_vm(task_id)
|
||||
cleaned += 1
|
||||
except Exception as e:
|
||||
logger.error("Error cleaning %s: %s", task_id, e, exc_info=True)
|
||||
|
||||
|
||||
# Also clean any orphaned directories
|
||||
scratch_dir = _get_scratch_dir()
|
||||
import glob
|
||||
|
||||
for path in glob.glob(str(scratch_dir / "hermes-*")):
|
||||
try:
|
||||
shutil.rmtree(path, ignore_errors=True)
|
||||
logger.info("Removed orphaned: %s", path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
if cleaned > 0:
|
||||
logger.info("Cleaned %d environments", cleaned)
|
||||
return cleaned
|
||||
|
|
@ -713,6 +754,7 @@ def cleanup_vm(task_id: str):
|
|||
# Invalidate stale file_ops cache entry
|
||||
try:
|
||||
from tools.file_tools import clear_file_ops_cache
|
||||
|
||||
clear_file_ops_cache(task_id)
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
@ -721,11 +763,11 @@ def cleanup_vm(task_id: str):
|
|||
return
|
||||
|
||||
try:
|
||||
if hasattr(env, 'cleanup'):
|
||||
if hasattr(env, "cleanup"):
|
||||
env.cleanup()
|
||||
elif hasattr(env, 'stop'):
|
||||
elif hasattr(env, "stop"):
|
||||
env.stop()
|
||||
elif hasattr(env, 'terminate'):
|
||||
elif hasattr(env, "terminate"):
|
||||
env.terminate()
|
||||
|
||||
logger.info("Manually cleaned up environment for task: %s", task_id)
|
||||
|
|
@ -746,17 +788,18 @@ def _atexit_cleanup():
|
|||
logger.info("Shutting down %d remaining sandbox(es)...", count)
|
||||
cleanup_all_environments()
|
||||
|
||||
|
||||
atexit.register(_atexit_cleanup)
|
||||
|
||||
|
||||
def terminal_tool(
|
||||
command: str,
|
||||
background: bool = False,
|
||||
timeout: Optional[int] = None,
|
||||
task_id: Optional[str] = None,
|
||||
timeout: int | None = None,
|
||||
task_id: str | None = None,
|
||||
force: bool = False,
|
||||
workdir: Optional[str] = None,
|
||||
check_interval: Optional[int] = None,
|
||||
workdir: str | None = None,
|
||||
check_interval: int | None = None,
|
||||
pty: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
|
|
@ -784,7 +827,7 @@ def terminal_tool(
|
|||
|
||||
# With custom timeout
|
||||
>>> result = terminal_tool(command="long_task.sh", timeout=300)
|
||||
|
||||
|
||||
# Force run after user confirmation
|
||||
# Note: force parameter is internal only, not exposed to model API
|
||||
"""
|
||||
|
|
@ -801,7 +844,7 @@ def terminal_tool(
|
|||
# Check per-task overrides (set by environments like TerminalBench2Env)
|
||||
# before falling back to global env var config
|
||||
overrides = _task_env_overrides.get(effective_task_id, {})
|
||||
|
||||
|
||||
# Select image based on env type, with per-task override support
|
||||
if env_type == "docker":
|
||||
image = overrides.get("docker_image") or config["docker_image"]
|
||||
|
|
@ -882,12 +925,15 @@ def terminal_tool(
|
|||
task_id=effective_task_id,
|
||||
)
|
||||
except ImportError as e:
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": f"Terminal tool disabled: mini-swe-agent not available ({e})",
|
||||
"status": "disabled"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": f"Terminal tool disabled: mini-swe-agent not available ({e})",
|
||||
"status": "disabled",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
with _env_lock:
|
||||
_active_environments[effective_task_id] = new_env
|
||||
|
|
@ -902,27 +948,33 @@ def terminal_tool(
|
|||
if not approval["approved"]:
|
||||
# Check if this is an approval_required (gateway ask mode)
|
||||
if approval.get("status") == "approval_required":
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": approval.get("message", "Waiting for user approval"),
|
||||
"status": "approval_required",
|
||||
"command": approval.get("command", command),
|
||||
"description": approval.get("description", "dangerous command"),
|
||||
"pattern_key": approval.get("pattern_key", ""),
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": approval.get("message", "Waiting for user approval"),
|
||||
"status": "approval_required",
|
||||
"command": approval.get("command", command),
|
||||
"description": approval.get("description", "dangerous command"),
|
||||
"pattern_key": approval.get("pattern_key", ""),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
# Command was blocked - include the pattern category so the caller knows why
|
||||
desc = approval.get("description", "potentially dangerous operation")
|
||||
fallback_msg = (
|
||||
f"Command denied: matches '{desc}' pattern. "
|
||||
"Use the approval prompt to allow it, or rephrase the command."
|
||||
)
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": approval.get("message", fallback_msg),
|
||||
"status": "blocked"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": approval.get("message", fallback_msg),
|
||||
"status": "blocked",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# Prepare command for execution
|
||||
if background:
|
||||
|
|
@ -940,7 +992,7 @@ def terminal_tool(
|
|||
cwd=effective_cwd,
|
||||
task_id=effective_task_id,
|
||||
session_key=session_key,
|
||||
env_vars=env.env if hasattr(env, 'env') else None,
|
||||
env_vars=env.env if hasattr(env, "env") else None,
|
||||
use_pty=pty,
|
||||
)
|
||||
else:
|
||||
|
|
@ -964,38 +1016,36 @@ def terminal_tool(
|
|||
max_timeout = effective_timeout
|
||||
if timeout and timeout > max_timeout:
|
||||
result_data["timeout_note"] = (
|
||||
f"Requested timeout {timeout}s was clamped to "
|
||||
f"configured limit of {max_timeout}s"
|
||||
f"Requested timeout {timeout}s was clamped to configured limit of {max_timeout}s"
|
||||
)
|
||||
|
||||
# Register check_interval watcher (gateway picks this up after agent run)
|
||||
if check_interval and background:
|
||||
effective_interval = max(30, check_interval)
|
||||
if check_interval < 30:
|
||||
result_data["check_interval_note"] = (
|
||||
f"Requested {check_interval}s raised to minimum 30s"
|
||||
)
|
||||
process_registry.pending_watchers.append({
|
||||
"session_id": proc_session.id,
|
||||
"check_interval": effective_interval,
|
||||
"session_key": session_key,
|
||||
"platform": os.getenv("HERMES_SESSION_PLATFORM", ""),
|
||||
"chat_id": os.getenv("HERMES_SESSION_CHAT_ID", ""),
|
||||
})
|
||||
result_data["check_interval_note"] = f"Requested {check_interval}s raised to minimum 30s"
|
||||
process_registry.pending_watchers.append(
|
||||
{
|
||||
"session_id": proc_session.id,
|
||||
"check_interval": effective_interval,
|
||||
"session_key": session_key,
|
||||
"platform": os.getenv("HERMES_SESSION_PLATFORM", ""),
|
||||
"chat_id": os.getenv("HERMES_SESSION_CHAT_ID", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return json.dumps(result_data, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": f"Failed to start background process: {str(e)}"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"output": "", "exit_code": -1, "error": f"Failed to start background process: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
else:
|
||||
# Run foreground command with retry logic
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
result = None
|
||||
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
execute_kwargs = {"timeout": effective_timeout}
|
||||
|
|
@ -1005,39 +1055,61 @@ def terminal_tool(
|
|||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if "timeout" in error_str:
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"exit_code": 124,
|
||||
"error": f"Command timed out after {effective_timeout} seconds"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"output": "",
|
||||
"exit_code": 124,
|
||||
"error": f"Command timed out after {effective_timeout} seconds",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# Retry on transient errors
|
||||
if retry_count < max_retries:
|
||||
retry_count += 1
|
||||
wait_time = 2 ** retry_count
|
||||
logger.warning("Execution error, retrying in %ds (attempt %d/%d) - Command: %s - Error: %s: %s - Task: %s, Backend: %s",
|
||||
wait_time, retry_count, max_retries, command[:200], type(e).__name__, e, effective_task_id, env_type)
|
||||
wait_time = 2**retry_count
|
||||
logger.warning(
|
||||
"Execution error, retrying in %ds (attempt %d/%d) - Command: %s - Error: %s: %s - Task: %s, Backend: %s",
|
||||
wait_time,
|
||||
retry_count,
|
||||
max_retries,
|
||||
command[:200],
|
||||
type(e).__name__,
|
||||
e,
|
||||
effective_task_id,
|
||||
env_type,
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
|
||||
logger.error("Execution failed after %d retries - Command: %s - Error: %s: %s - Task: %s, Backend: %s",
|
||||
max_retries, command[:200], type(e).__name__, e, effective_task_id, env_type)
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": f"Command execution failed: {type(e).__name__}: {str(e)}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
logger.error(
|
||||
"Execution failed after %d retries - Command: %s - Error: %s: %s - Task: %s, Backend: %s",
|
||||
max_retries,
|
||||
command[:200],
|
||||
type(e).__name__,
|
||||
e,
|
||||
effective_task_id,
|
||||
env_type,
|
||||
)
|
||||
return json.dumps(
|
||||
{
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": f"Command execution failed: {type(e).__name__}: {str(e)}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# Got a result
|
||||
break
|
||||
|
||||
|
||||
# Extract output
|
||||
output = result.get("output", "")
|
||||
returncode = result.get("returncode", 0)
|
||||
|
||||
|
||||
# Add helpful message for sudo failures in messaging context
|
||||
output = _handle_sudo_failure(output, env_type)
|
||||
|
||||
|
||||
# Truncate output if too long, keeping both head and tail
|
||||
MAX_OUTPUT_CHARS = 50000
|
||||
if len(output) > MAX_OUTPUT_CHARS:
|
||||
|
|
@ -1045,65 +1117,56 @@ def terminal_tool(
|
|||
tail_chars = MAX_OUTPUT_CHARS - head_chars # 60% tail (most recent/relevant output)
|
||||
omitted = len(output) - head_chars - tail_chars
|
||||
truncated_notice = (
|
||||
f"\n\n... [OUTPUT TRUNCATED - {omitted} chars omitted "
|
||||
f"out of {len(output)} total] ...\n\n"
|
||||
f"\n\n... [OUTPUT TRUNCATED - {omitted} chars omitted out of {len(output)} total] ...\n\n"
|
||||
)
|
||||
output = output[:head_chars] + truncated_notice + output[-tail_chars:]
|
||||
|
||||
# Redact secrets from command output (catches env/printenv leaking keys)
|
||||
from agent.redact import redact_sensitive_text
|
||||
|
||||
output = redact_sensitive_text(output.strip()) if output else ""
|
||||
|
||||
return json.dumps({
|
||||
"output": output,
|
||||
"exit_code": returncode,
|
||||
"error": None
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps({"output": output, "exit_code": returncode, "error": None}, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": f"Failed to execute command: {str(e)}",
|
||||
"status": "error"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"output": "", "exit_code": -1, "error": f"Failed to execute command: {str(e)}", "status": "error"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
def check_terminal_requirements() -> bool:
|
||||
"""Check if all requirements for the terminal tool are met."""
|
||||
config = _get_env_config()
|
||||
env_type = config["env_type"]
|
||||
|
||||
|
||||
try:
|
||||
if env_type == "local":
|
||||
from minisweagent.environments.local import LocalEnvironment
|
||||
return True
|
||||
elif env_type == "docker":
|
||||
from minisweagent.environments.docker import DockerEnvironment
|
||||
# Check if docker is available
|
||||
import subprocess
|
||||
|
||||
result = subprocess.run(["docker", "version"], capture_output=True, timeout=5)
|
||||
return result.returncode == 0
|
||||
elif env_type == "singularity":
|
||||
from minisweagent.environments.singularity import SingularityEnvironment
|
||||
import shutil
|
||||
|
||||
# Check if singularity/apptainer is available
|
||||
import subprocess
|
||||
import shutil
|
||||
|
||||
executable = shutil.which("apptainer") or shutil.which("singularity")
|
||||
if executable:
|
||||
result = subprocess.run([executable, "--version"], capture_output=True, timeout=5)
|
||||
return result.returncode == 0
|
||||
return False
|
||||
elif env_type == "ssh":
|
||||
from tools.environments.ssh import SSHEnvironment
|
||||
# Check that host and user are configured
|
||||
return bool(config.get("ssh_host")) and bool(config.get("ssh_user"))
|
||||
elif env_type == "modal":
|
||||
from minisweagent.environments.extra.swerex_modal import SwerexModalEnvironment
|
||||
# Check for modal token
|
||||
return os.getenv("MODAL_TOKEN_ID") is not None or Path.home().joinpath(".modal.toml").exists()
|
||||
elif env_type == "daytona":
|
||||
from daytona import Daytona
|
||||
return os.getenv("DAYTONA_API_KEY") is not None
|
||||
else:
|
||||
return False
|
||||
|
|
@ -1116,9 +1179,9 @@ if __name__ == "__main__":
|
|||
# Simple test when run directly
|
||||
print("Terminal Tool Module (mini-swe-agent backend)")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
config = _get_env_config()
|
||||
print(f"\nCurrent Configuration:")
|
||||
print("\nCurrent Configuration:")
|
||||
print(f" Environment type: {config['env_type']}")
|
||||
print(f" Docker image: {config['docker_image']}")
|
||||
print(f" Modal image: {config['modal_image']}")
|
||||
|
|
@ -1165,37 +1228,34 @@ TERMINAL_SCHEMA = {
|
|||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The command to execute on the VM"
|
||||
},
|
||||
"command": {"type": "string", "description": "The command to execute on the VM"},
|
||||
"background": {
|
||||
"type": "boolean",
|
||||
"description": "ONLY for servers/watchers that never exit. For scripts, builds, installs — use foreground with timeout instead (it returns instantly when done).",
|
||||
"default": False
|
||||
"default": False,
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Max seconds to wait (default: 180). Returns INSTANTLY when command finishes — set high for long tasks, you won't wait unnecessarily.",
|
||||
"minimum": 1
|
||||
"minimum": 1,
|
||||
},
|
||||
"workdir": {
|
||||
"type": "string",
|
||||
"description": "Working directory for this command (absolute path). Defaults to the session working directory."
|
||||
"description": "Working directory for this command (absolute path). Defaults to the session working directory.",
|
||||
},
|
||||
"check_interval": {
|
||||
"type": "integer",
|
||||
"description": "Seconds between automatic status checks for background processes (gateway/messaging only, minimum 30). When set, I'll proactively report progress.",
|
||||
"minimum": 30
|
||||
"minimum": 30,
|
||||
},
|
||||
"pty": {
|
||||
"type": "boolean",
|
||||
"description": "Run in pseudo-terminal (PTY) mode for interactive CLI tools like Codex, Claude Code, or Python REPL. Only works with local and SSH backends. Default: false.",
|
||||
"default": False
|
||||
}
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
"required": ["command"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,8 +15,7 @@ Design:
|
|||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from typing import Any
|
||||
|
||||
# Valid status values for todo items
|
||||
VALID_STATUSES = {"pending", "in_progress", "completed", "cancelled"}
|
||||
|
|
@ -33,9 +32,9 @@ class TodoStore:
|
|||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._items: List[Dict[str, str]] = []
|
||||
self._items: list[dict[str, str]] = []
|
||||
|
||||
def write(self, todos: List[Dict[str, Any]], merge: bool = False) -> List[Dict[str, str]]:
|
||||
def write(self, todos: list[dict[str, Any]], merge: bool = False) -> list[dict[str, str]]:
|
||||
"""
|
||||
Write todos. Returns the full current list after writing.
|
||||
|
||||
|
|
@ -79,7 +78,7 @@ class TodoStore:
|
|||
self._items = rebuilt
|
||||
return self.read()
|
||||
|
||||
def read(self) -> List[Dict[str, str]]:
|
||||
def read(self) -> list[dict[str, str]]:
|
||||
"""Return a copy of the current list."""
|
||||
return [item.copy() for item in self._items]
|
||||
|
||||
|
|
@ -87,7 +86,7 @@ class TodoStore:
|
|||
"""Check if there are any items in the list."""
|
||||
return len(self._items) > 0
|
||||
|
||||
def format_for_injection(self) -> Optional[str]:
|
||||
def format_for_injection(self) -> str | None:
|
||||
"""
|
||||
Render the todo list for post-compression injection.
|
||||
|
||||
|
|
@ -113,7 +112,7 @@ class TodoStore:
|
|||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _validate(item: Dict[str, Any]) -> Dict[str, str]:
|
||||
def _validate(item: dict[str, Any]) -> dict[str, str]:
|
||||
"""
|
||||
Validate and normalize a todo item.
|
||||
|
||||
|
|
@ -136,9 +135,9 @@ class TodoStore:
|
|||
|
||||
|
||||
def todo_tool(
|
||||
todos: Optional[List[Dict[str, Any]]] = None,
|
||||
todos: list[dict[str, Any]] | None = None,
|
||||
merge: bool = False,
|
||||
store: Optional[TodoStore] = None,
|
||||
store: TodoStore | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Single entry point for the todo tool. Reads or writes depending on params.
|
||||
|
|
@ -165,16 +164,19 @@ def todo_tool(
|
|||
completed = sum(1 for i in items if i["status"] == "completed")
|
||||
cancelled = sum(1 for i in items if i["status"] == "cancelled")
|
||||
|
||||
return json.dumps({
|
||||
"todos": items,
|
||||
"summary": {
|
||||
"total": len(items),
|
||||
"pending": pending,
|
||||
"in_progress": in_progress,
|
||||
"completed": completed,
|
||||
"cancelled": cancelled,
|
||||
return json.dumps(
|
||||
{
|
||||
"todos": items,
|
||||
"summary": {
|
||||
"total": len(items),
|
||||
"pending": pending,
|
||||
"in_progress": in_progress,
|
||||
"completed": completed,
|
||||
"cancelled": cancelled,
|
||||
},
|
||||
},
|
||||
}, ensure_ascii=False)
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
def check_todo_requirements() -> bool:
|
||||
|
|
@ -214,34 +216,27 @@ TODO_SCHEMA = {
|
|||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"description": "Unique item identifier"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Task description"
|
||||
},
|
||||
"id": {"type": "string", "description": "Unique item identifier"},
|
||||
"content": {"type": "string", "description": "Task description"},
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["pending", "in_progress", "completed", "cancelled"],
|
||||
"description": "Current status"
|
||||
}
|
||||
"description": "Current status",
|
||||
},
|
||||
},
|
||||
"required": ["id", "content", "status"]
|
||||
}
|
||||
"required": ["id", "content", "status"],
|
||||
},
|
||||
},
|
||||
"merge": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"true: update existing items by id, add new ones. "
|
||||
"false (default): replace the entire list."
|
||||
"true: update existing items by id, add new ones. false (default): replace the entire list."
|
||||
),
|
||||
"default": False
|
||||
}
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -253,6 +248,7 @@ registry.register(
|
|||
toolset="todo",
|
||||
schema=TODO_SCHEMA,
|
||||
handler=lambda args, **kw: todo_tool(
|
||||
todos=args.get("todos"), merge=args.get("merge", False), store=kw.get("store")),
|
||||
todos=args.get("todos"), merge=args.get("merge", False), store=kw.get("store")
|
||||
),
|
||||
check_fn=check_todo_requirements,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ Usage:
|
|||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -39,7 +39,7 @@ SUPPORTED_FORMATS = {".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm",
|
|||
MAX_FILE_SIZE = 25 * 1024 * 1024
|
||||
|
||||
|
||||
def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, Any]:
|
||||
def transcribe_audio(file_path: str, model: str | None = None) -> dict[str, Any]:
|
||||
"""
|
||||
Transcribe an audio file using OpenAI's Whisper API.
|
||||
|
||||
|
|
@ -65,7 +65,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
|
|||
}
|
||||
|
||||
audio_path = Path(file_path)
|
||||
|
||||
|
||||
# Validate file exists
|
||||
if not audio_path.exists():
|
||||
return {
|
||||
|
|
@ -73,14 +73,14 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
|
|||
"transcript": "",
|
||||
"error": f"Audio file not found: {file_path}",
|
||||
}
|
||||
|
||||
|
||||
if not audio_path.is_file():
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"Path is not a file: {file_path}",
|
||||
}
|
||||
|
||||
|
||||
# Validate file extension
|
||||
if audio_path.suffix.lower() not in SUPPORTED_FORMATS:
|
||||
return {
|
||||
|
|
@ -88,7 +88,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
|
|||
"transcript": "",
|
||||
"error": f"Unsupported file format: {audio_path.suffix}. Supported formats: {', '.join(sorted(SUPPORTED_FORMATS))}",
|
||||
}
|
||||
|
||||
|
||||
# Validate file size
|
||||
try:
|
||||
file_size = audio_path.stat().st_size
|
||||
|
|
@ -96,7 +96,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
|
|||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"File too large: {file_size / (1024*1024):.1f}MB (max {MAX_FILE_SIZE / (1024*1024)}MB)",
|
||||
"error": f"File too large: {file_size / (1024 * 1024):.1f}MB (max {MAX_FILE_SIZE / (1024 * 1024)}MB)",
|
||||
}
|
||||
except OSError as e:
|
||||
logger.error("Failed to get file size for %s: %s", file_path, e, exc_info=True)
|
||||
|
|
@ -111,7 +111,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
|
|||
model = DEFAULT_STT_MODEL
|
||||
|
||||
try:
|
||||
from openai import OpenAI, APIError, APIConnectionError, APITimeoutError
|
||||
from openai import APIConnectionError, APIError, APITimeoutError, OpenAI
|
||||
|
||||
client = OpenAI(api_key=api_key, base_url="https://api.openai.com/v1")
|
||||
|
||||
|
|
|
|||
|
|
@ -27,9 +27,8 @@ import logging
|
|||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -38,12 +37,14 @@ logger = logging.getLogger(__name__)
|
|||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
import edge_tts
|
||||
|
||||
_HAS_EDGE_TTS = True
|
||||
except ImportError:
|
||||
_HAS_EDGE_TTS = False
|
||||
|
||||
try:
|
||||
from elevenlabs.client import ElevenLabs
|
||||
|
||||
_HAS_ELEVENLABS = True
|
||||
except ImportError:
|
||||
_HAS_ELEVENLABS = False
|
||||
|
|
@ -51,6 +52,7 @@ except ImportError:
|
|||
# openai is a core dependency, but guard anyway
|
||||
try:
|
||||
from openai import OpenAI as OpenAIClient
|
||||
|
||||
_HAS_OPENAI = True
|
||||
except ImportError:
|
||||
_HAS_OPENAI = False
|
||||
|
|
@ -72,7 +74,7 @@ MAX_TEXT_LENGTH = 4000
|
|||
# ===========================================================================
|
||||
# Config loader -- reads tts: section from ~/.hermes/config.yaml
|
||||
# ===========================================================================
|
||||
def _load_tts_config() -> Dict[str, Any]:
|
||||
def _load_tts_config() -> dict[str, Any]:
|
||||
"""
|
||||
Load TTS configuration from ~/.hermes/config.yaml.
|
||||
|
||||
|
|
@ -81,13 +83,14 @@ def _load_tts_config() -> Dict[str, Any]:
|
|||
"""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
return config.get("tts", {})
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _get_provider(tts_config: Dict[str, Any]) -> str:
|
||||
def _get_provider(tts_config: dict[str, Any]) -> str:
|
||||
"""Get the configured TTS provider name."""
|
||||
return tts_config.get("provider", DEFAULT_PROVIDER).lower().strip()
|
||||
|
||||
|
|
@ -100,7 +103,7 @@ def _has_ffmpeg() -> bool:
|
|||
return shutil.which("ffmpeg") is not None
|
||||
|
||||
|
||||
def _convert_to_opus(mp3_path: str) -> Optional[str]:
|
||||
def _convert_to_opus(mp3_path: str) -> str | None:
|
||||
"""
|
||||
Convert an MP3 file to OGG Opus format for Telegram voice bubbles.
|
||||
|
||||
|
|
@ -116,9 +119,9 @@ def _convert_to_opus(mp3_path: str) -> Optional[str]:
|
|||
ogg_path = mp3_path.rsplit(".", 1)[0] + ".ogg"
|
||||
try:
|
||||
subprocess.run(
|
||||
["ffmpeg", "-i", mp3_path, "-acodec", "libopus",
|
||||
"-ac", "1", "-b:a", "64k", "-vbr", "off", ogg_path, "-y"],
|
||||
capture_output=True, timeout=30,
|
||||
["ffmpeg", "-i", mp3_path, "-acodec", "libopus", "-ac", "1", "-b:a", "64k", "-vbr", "off", ogg_path, "-y"],
|
||||
capture_output=True,
|
||||
timeout=30,
|
||||
)
|
||||
if os.path.exists(ogg_path) and os.path.getsize(ogg_path) > 0:
|
||||
return ogg_path
|
||||
|
|
@ -130,7 +133,7 @@ def _convert_to_opus(mp3_path: str) -> Optional[str]:
|
|||
# ===========================================================================
|
||||
# Provider: Edge TTS (free)
|
||||
# ===========================================================================
|
||||
async def _generate_edge_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -> str:
|
||||
async def _generate_edge_tts(text: str, output_path: str, tts_config: dict[str, Any]) -> str:
|
||||
"""
|
||||
Generate audio using Edge TTS.
|
||||
|
||||
|
|
@ -153,7 +156,7 @@ async def _generate_edge_tts(text: str, output_path: str, tts_config: Dict[str,
|
|||
# ===========================================================================
|
||||
# Provider: ElevenLabs (premium)
|
||||
# ===========================================================================
|
||||
def _generate_elevenlabs(text: str, output_path: str, tts_config: Dict[str, Any]) -> str:
|
||||
def _generate_elevenlabs(text: str, output_path: str, tts_config: dict[str, Any]) -> str:
|
||||
"""
|
||||
Generate audio using ElevenLabs.
|
||||
|
||||
|
|
@ -198,7 +201,7 @@ def _generate_elevenlabs(text: str, output_path: str, tts_config: Dict[str, Any]
|
|||
# ===========================================================================
|
||||
# Provider: OpenAI TTS
|
||||
# ===========================================================================
|
||||
def _generate_openai_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -> str:
|
||||
def _generate_openai_tts(text: str, output_path: str, tts_config: dict[str, Any]) -> str:
|
||||
"""
|
||||
Generate audio using OpenAI TTS.
|
||||
|
||||
|
|
@ -241,7 +244,7 @@ def _generate_openai_tts(text: str, output_path: str, tts_config: Dict[str, Any]
|
|||
# ===========================================================================
|
||||
def text_to_speech_tool(
|
||||
text: str,
|
||||
output_path: Optional[str] = None,
|
||||
output_path: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Convert text to speech audio.
|
||||
|
|
@ -276,7 +279,7 @@ def text_to_speech_tool(
|
|||
# produce Opus natively (no ffmpeg needed). Edge TTS always outputs MP3
|
||||
# and needs ffmpeg for conversion.
|
||||
platform = os.getenv("HERMES_SESSION_PLATFORM", "").lower()
|
||||
want_opus = (platform == "telegram")
|
||||
want_opus = platform == "telegram"
|
||||
|
||||
# Determine output path
|
||||
if output_path:
|
||||
|
|
@ -300,47 +303,48 @@ def text_to_speech_tool(
|
|||
# Generate audio with the configured provider
|
||||
if provider == "elevenlabs":
|
||||
if not _HAS_ELEVENLABS:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "ElevenLabs provider selected but 'elevenlabs' package not installed. Run: pip install elevenlabs"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "ElevenLabs provider selected but 'elevenlabs' package not installed. Run: pip install elevenlabs",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
logger.info("Generating speech with ElevenLabs...")
|
||||
_generate_elevenlabs(text, file_str, tts_config)
|
||||
|
||||
elif provider == "openai":
|
||||
if not _HAS_OPENAI:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "OpenAI provider selected but 'openai' package not installed."
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "OpenAI provider selected but 'openai' package not installed."},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
logger.info("Generating speech with OpenAI TTS...")
|
||||
_generate_openai_tts(text, file_str, tts_config)
|
||||
|
||||
else:
|
||||
# Default: Edge TTS (free)
|
||||
if not _HAS_EDGE_TTS:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "Edge TTS not available. Run: pip install edge-tts"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "Edge TTS not available. Run: pip install edge-tts"}, ensure_ascii=False
|
||||
)
|
||||
logger.info("Generating speech with Edge TTS...")
|
||||
# Edge TTS is async, run it
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
pool.submit(
|
||||
lambda: asyncio.run(_generate_edge_tts(text, file_str, tts_config))
|
||||
).result(timeout=60)
|
||||
pool.submit(lambda: asyncio.run(_generate_edge_tts(text, file_str, tts_config))).result(timeout=60)
|
||||
except RuntimeError:
|
||||
asyncio.run(_generate_edge_tts(text, file_str, tts_config))
|
||||
|
||||
# Check the file was actually created
|
||||
if not os.path.exists(file_str) or os.path.getsize(file_str) == 0:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": f"TTS generation produced no output (provider: {provider})"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": f"TTS generation produced no output (provider: {provider})"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# Try Opus conversion for Telegram compatibility (Edge TTS only outputs MP3)
|
||||
voice_compatible = False
|
||||
|
|
@ -361,13 +365,16 @@ def text_to_speech_tool(
|
|||
if voice_compatible:
|
||||
media_tag = f"[[audio_as_voice]]\n{media_tag}"
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"file_path": file_str,
|
||||
"media_tag": media_tag,
|
||||
"provider": provider,
|
||||
"voice_compatible": voice_compatible,
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"file_path": file_str,
|
||||
"media_tag": media_tag,
|
||||
"provider": provider,
|
||||
"voice_compatible": voice_compatible,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"TTS generation failed ({provider}): {e}"
|
||||
|
|
@ -404,7 +411,7 @@ if __name__ == "__main__":
|
|||
print("🔊 Text-to-Speech Tool Module")
|
||||
print("=" * 50)
|
||||
|
||||
print(f"\nProvider availability:")
|
||||
print("\nProvider availability:")
|
||||
print(f" Edge TTS: {'✅ installed' if _HAS_EDGE_TTS else '❌ not installed (pip install edge-tts)'}")
|
||||
print(f" ElevenLabs: {'✅ installed' if _HAS_ELEVENLABS else '❌ not installed (pip install elevenlabs)'}")
|
||||
print(f" API Key: {'✅ set' if os.getenv('ELEVENLABS_API_KEY') else '❌ not set'}")
|
||||
|
|
@ -429,25 +436,20 @@ TTS_SCHEMA = {
|
|||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The text to convert to speech. Keep under 4000 characters."
|
||||
},
|
||||
"text": {"type": "string", "description": "The text to convert to speech. Keep under 4000 characters."},
|
||||
"output_path": {
|
||||
"type": "string",
|
||||
"description": "Optional custom file path to save the audio. Defaults to ~/.hermes/audio_cache/<timestamp>.mp3"
|
||||
}
|
||||
"description": "Optional custom file path to save the audio. Defaults to ~/.hermes/audio_cache/<timestamp>.mp3",
|
||||
},
|
||||
},
|
||||
"required": ["text"]
|
||||
}
|
||||
"required": ["text"],
|
||||
},
|
||||
}
|
||||
|
||||
registry.register(
|
||||
name="text_to_speech",
|
||||
toolset="tts",
|
||||
schema=TTS_SCHEMA,
|
||||
handler=lambda args, **kw: text_to_speech_tool(
|
||||
text=args.get("text", ""),
|
||||
output_path=args.get("output_path")),
|
||||
handler=lambda args, **kw: text_to_speech_tool(text=args.get("text", ""), output_path=args.get("output_path")),
|
||||
check_fn=check_tts_requirements,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ Features:
|
|||
Usage:
|
||||
from vision_tools import vision_analyze_tool
|
||||
import asyncio
|
||||
|
||||
|
||||
# Analyze an image
|
||||
result = await vision_analyze_tool(
|
||||
image_url="https://example.com/image.jpg",
|
||||
|
|
@ -33,11 +33,14 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Awaitable
|
||||
from pathlib import Path
|
||||
from typing import Any, Awaitable, Dict, Optional
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from agent.auxiliary_client import get_vision_auxiliary_client
|
||||
from tools.debug_helpers import DebugSession
|
||||
|
||||
|
|
@ -55,7 +58,7 @@ if _aux_sync_client is not None:
|
|||
_async_kwargs["default_headers"] = {
|
||||
"HTTP-Referer": "https://github.com/NousResearch/hermes-agent",
|
||||
"X-OpenRouter-Title": "Hermes Agent",
|
||||
"X-OpenRouter-Categories": "productivity,cli-agent",
|
||||
"X-OpenRouter-Categories": "productivity,cli-agent",
|
||||
}
|
||||
_aux_async_client = AsyncOpenAI(**_async_kwargs)
|
||||
|
||||
|
|
@ -65,10 +68,10 @@ _debug = DebugSession("vision_tools", env_var="VISION_TOOLS_DEBUG")
|
|||
def _validate_image_url(url: str) -> bool:
|
||||
"""
|
||||
Basic validation of image URL format.
|
||||
|
||||
|
||||
Args:
|
||||
url (str): The URL to validate
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if URL appears to be valid, False otherwise
|
||||
"""
|
||||
|
|
@ -91,23 +94,22 @@ def _validate_image_url(url: str) -> bool:
|
|||
async def _download_image(image_url: str, destination: Path, max_retries: int = 3) -> Path:
|
||||
"""
|
||||
Download an image from a URL to a local destination (async) with retry logic.
|
||||
|
||||
|
||||
Args:
|
||||
image_url (str): The URL of the image to download
|
||||
destination (Path): The path where the image should be saved
|
||||
max_retries (int): Maximum number of retry attempts (default: 3)
|
||||
|
||||
|
||||
Returns:
|
||||
Path: The path to the downloaded image
|
||||
|
||||
|
||||
Raises:
|
||||
Exception: If download fails after all retries
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
|
||||
# Create parent directories if they don't exist
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
last_error = None
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
|
|
@ -122,10 +124,10 @@ async def _download_image(image_url: str, destination: Path, max_retries: int =
|
|||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
# Save the image content
|
||||
destination.write_bytes(response.content)
|
||||
|
||||
|
||||
return destination
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
|
|
@ -141,56 +143,56 @@ async def _download_image(image_url: str, destination: Path, max_retries: int =
|
|||
str(e)[:100],
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
raise last_error
|
||||
|
||||
|
||||
def _determine_mime_type(image_path: Path) -> str:
|
||||
"""
|
||||
Determine the MIME type of an image based on its file extension.
|
||||
|
||||
|
||||
Args:
|
||||
image_path (Path): Path to the image file
|
||||
|
||||
|
||||
Returns:
|
||||
str: The MIME type (defaults to image/jpeg if unknown)
|
||||
"""
|
||||
extension = image_path.suffix.lower()
|
||||
mime_types = {
|
||||
'.jpg': 'image/jpeg',
|
||||
'.jpeg': 'image/jpeg',
|
||||
'.png': 'image/png',
|
||||
'.gif': 'image/gif',
|
||||
'.bmp': 'image/bmp',
|
||||
'.webp': 'image/webp',
|
||||
'.svg': 'image/svg+xml'
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".png": "image/png",
|
||||
".gif": "image/gif",
|
||||
".bmp": "image/bmp",
|
||||
".webp": "image/webp",
|
||||
".svg": "image/svg+xml",
|
||||
}
|
||||
return mime_types.get(extension, 'image/jpeg')
|
||||
return mime_types.get(extension, "image/jpeg")
|
||||
|
||||
|
||||
def _image_to_base64_data_url(image_path: Path, mime_type: Optional[str] = None) -> str:
|
||||
def _image_to_base64_data_url(image_path: Path, mime_type: str | None = None) -> str:
|
||||
"""
|
||||
Convert an image file to a base64-encoded data URL.
|
||||
|
||||
|
||||
Args:
|
||||
image_path (Path): Path to the image file
|
||||
mime_type (Optional[str]): MIME type of the image (auto-detected if None)
|
||||
|
||||
|
||||
Returns:
|
||||
str: Base64-encoded data URL (e.g., "data:image/jpeg;base64,...")
|
||||
"""
|
||||
# Read the image as bytes
|
||||
data = image_path.read_bytes()
|
||||
|
||||
|
||||
# Encode to base64
|
||||
encoded = base64.b64encode(data).decode("ascii")
|
||||
|
||||
|
||||
# Determine MIME type
|
||||
mime = mime_type or _determine_mime_type(image_path)
|
||||
|
||||
|
||||
# Create data URL
|
||||
data_url = f"data:{mime};base64,{encoded}"
|
||||
|
||||
|
||||
return data_url
|
||||
|
||||
|
||||
|
|
@ -201,31 +203,31 @@ async def vision_analyze_tool(
|
|||
) -> str:
|
||||
"""
|
||||
Analyze an image from a URL or local file path using vision AI.
|
||||
|
||||
|
||||
This tool accepts either an HTTP/HTTPS URL or a local file path. For URLs,
|
||||
it downloads the image first. In both cases, the image is converted to base64
|
||||
and processed using Gemini 3 Flash Preview via OpenRouter API.
|
||||
|
||||
|
||||
The user_prompt parameter is expected to be pre-formatted by the calling
|
||||
function (typically model_tools.py) to include both full description
|
||||
requests and specific questions.
|
||||
|
||||
|
||||
Args:
|
||||
image_url (str): The URL or local file path of the image to analyze.
|
||||
Accepts http://, https:// URLs or absolute/relative file paths.
|
||||
user_prompt (str): The pre-formatted prompt for the vision model
|
||||
model (str): The vision model to use (default: google/gemini-3-flash-preview)
|
||||
|
||||
|
||||
Returns:
|
||||
str: JSON string containing the analysis results with the following structure:
|
||||
{
|
||||
"success": bool,
|
||||
"analysis": str (defaults to error message if None)
|
||||
}
|
||||
|
||||
|
||||
Raises:
|
||||
Exception: If download fails, analysis fails, or API key is not set
|
||||
|
||||
|
||||
Note:
|
||||
- For URLs, temporary images are stored in ./temp_vision_images/ and cleaned up
|
||||
- For local file paths, the file is used directly and NOT deleted
|
||||
|
|
@ -235,36 +237,41 @@ async def vision_analyze_tool(
|
|||
"parameters": {
|
||||
"image_url": image_url,
|
||||
"user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt,
|
||||
"model": model
|
||||
"model": model,
|
||||
},
|
||||
"error": None,
|
||||
"success": False,
|
||||
"analysis_length": 0,
|
||||
"model_used": model,
|
||||
"image_size_bytes": 0
|
||||
"image_size_bytes": 0,
|
||||
}
|
||||
|
||||
|
||||
temp_image_path = None
|
||||
# Track whether we should clean up the file after processing.
|
||||
# Local files (e.g. from the image cache) should NOT be deleted.
|
||||
should_cleanup = True
|
||||
|
||||
|
||||
try:
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
if is_interrupted():
|
||||
return json.dumps({"success": False, "error": "Interrupted"})
|
||||
|
||||
logger.info("Analyzing image: %s", image_url[:60])
|
||||
logger.info("User prompt: %s", user_prompt[:100])
|
||||
|
||||
|
||||
# Check auxiliary vision client availability
|
||||
if _aux_async_client is None or DEFAULT_VISION_MODEL is None:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"analysis": "Vision analysis unavailable: no auxiliary vision model configured. "
|
||||
"Set OPENROUTER_API_KEY or configure Nous Portal to enable vision tools."
|
||||
}, indent=2, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"analysis": "Vision analysis unavailable: no auxiliary vision model configured. "
|
||||
"Set OPENROUTER_API_KEY or configure Nous Portal to enable vision tools.",
|
||||
},
|
||||
indent=2,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# Determine if this is a local file path or a remote URL
|
||||
local_path = Path(image_url)
|
||||
if local_path.is_file():
|
||||
|
|
@ -280,50 +287,41 @@ async def vision_analyze_tool(
|
|||
await _download_image(image_url, temp_image_path)
|
||||
should_cleanup = True
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid image source. Provide an HTTP/HTTPS URL or a valid local file path."
|
||||
)
|
||||
|
||||
raise ValueError("Invalid image source. Provide an HTTP/HTTPS URL or a valid local file path.")
|
||||
|
||||
# Get image file size for logging
|
||||
image_size_bytes = temp_image_path.stat().st_size
|
||||
image_size_kb = image_size_bytes / 1024
|
||||
logger.info("Image ready (%.1f KB)", image_size_kb)
|
||||
|
||||
|
||||
# Convert image to base64 data URL
|
||||
logger.info("Converting image to base64...")
|
||||
image_data_url = _image_to_base64_data_url(temp_image_path)
|
||||
# Calculate size in KB for better readability
|
||||
data_size_kb = len(image_data_url) / 1024
|
||||
logger.info("Image converted to base64 (%.1f KB)", data_size_kb)
|
||||
|
||||
|
||||
debug_call_data["image_size_bytes"] = image_size_bytes
|
||||
|
||||
|
||||
# Use the prompt as provided (model_tools.py now handles full description formatting)
|
||||
comprehensive_prompt = user_prompt
|
||||
|
||||
|
||||
# Prepare the message with base64-encoded image
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": comprehensive_prompt
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_data_url
|
||||
}
|
||||
}
|
||||
]
|
||||
{"type": "text", "text": comprehensive_prompt},
|
||||
{"type": "image_url", "image_url": {"url": image_data_url}},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
logger.info("Processing image with %s...", model)
|
||||
|
||||
|
||||
# Call the vision API
|
||||
from agent.auxiliary_client import get_auxiliary_extra_body, auxiliary_max_tokens_param
|
||||
from agent.auxiliary_client import auxiliary_max_tokens_param, get_auxiliary_extra_body
|
||||
|
||||
_extra = get_auxiliary_extra_body()
|
||||
response = await _aux_async_client.chat.completions.create(
|
||||
model=model,
|
||||
|
|
@ -332,44 +330,44 @@ async def vision_analyze_tool(
|
|||
**auxiliary_max_tokens_param(2000),
|
||||
**({} if not _extra else {"extra_body": _extra}),
|
||||
)
|
||||
|
||||
|
||||
# Extract the analysis
|
||||
analysis = response.choices[0].message.content.strip()
|
||||
analysis_length = len(analysis)
|
||||
|
||||
|
||||
logger.info("Image analysis completed (%s characters)", analysis_length)
|
||||
|
||||
|
||||
# Prepare successful response
|
||||
result = {
|
||||
"success": True,
|
||||
"analysis": analysis or "There was a problem with the request and the image could not be analyzed."
|
||||
"analysis": analysis or "There was a problem with the request and the image could not be analyzed.",
|
||||
}
|
||||
|
||||
|
||||
debug_call_data["success"] = True
|
||||
debug_call_data["analysis_length"] = analysis_length
|
||||
|
||||
|
||||
# Log debug information
|
||||
_debug.log_call("vision_analyze_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
|
||||
return json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error analyzing image: {str(e)}"
|
||||
logger.error("%s", error_msg, exc_info=True)
|
||||
|
||||
|
||||
# Prepare error response
|
||||
result = {
|
||||
"success": False,
|
||||
"analysis": "There was a problem with the request and the image could not be analyzed."
|
||||
"analysis": "There was a problem with the request and the image could not be analyzed.",
|
||||
}
|
||||
|
||||
|
||||
debug_call_data["error"] = error_msg
|
||||
_debug.log_call("vision_analyze_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
|
||||
return json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
finally:
|
||||
# Clean up temporary image file (but NOT local/cached files)
|
||||
if should_cleanup and temp_image_path and temp_image_path.exists():
|
||||
|
|
@ -377,9 +375,7 @@ async def vision_analyze_tool(
|
|||
temp_image_path.unlink()
|
||||
logger.debug("Cleaned up temporary image file")
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(
|
||||
"Could not delete temporary file: %s", cleanup_error, exc_info=True
|
||||
)
|
||||
logger.warning("Could not delete temporary file: %s", cleanup_error, exc_info=True)
|
||||
|
||||
|
||||
def check_vision_requirements() -> bool:
|
||||
|
|
@ -387,10 +383,10 @@ def check_vision_requirements() -> bool:
|
|||
return _aux_async_client is not None
|
||||
|
||||
|
||||
def get_debug_session_info() -> Dict[str, Any]:
|
||||
def get_debug_session_info() -> dict[str, Any]:
|
||||
"""
|
||||
Get information about the current debug session.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing debug session information
|
||||
"""
|
||||
|
|
@ -403,27 +399,27 @@ if __name__ == "__main__":
|
|||
"""
|
||||
print("👁️ Vision Tools Module")
|
||||
print("=" * 40)
|
||||
|
||||
|
||||
# Check if vision model is available
|
||||
api_available = check_vision_requirements()
|
||||
|
||||
|
||||
if not api_available:
|
||||
print("❌ No auxiliary vision model available")
|
||||
print("Set OPENROUTER_API_KEY or configure Nous Portal to enable vision tools.")
|
||||
exit(1)
|
||||
else:
|
||||
print(f"✅ Vision model available: {DEFAULT_VISION_MODEL}")
|
||||
|
||||
|
||||
print("🛠️ Vision tools ready for use!")
|
||||
print(f"🧠 Using model: {DEFAULT_VISION_MODEL}")
|
||||
|
||||
|
||||
# Show debug mode status
|
||||
if _debug.active:
|
||||
print(f"🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
|
||||
print(f" Debug logs will be saved to: ./logs/vision_tools_debug_{_debug.session_id}.json")
|
||||
else:
|
||||
print("🐛 Debug mode disabled (set VISION_TOOLS_DEBUG=true to enable)")
|
||||
|
||||
|
||||
print("\nBasic usage:")
|
||||
print(" from vision_tools import vision_analyze_tool")
|
||||
print(" import asyncio")
|
||||
|
|
@ -435,14 +431,14 @@ if __name__ == "__main__":
|
|||
print(" )")
|
||||
print(" print(result)")
|
||||
print(" asyncio.run(main())")
|
||||
|
||||
|
||||
print("\nExample prompts:")
|
||||
print(" - 'What architectural style is this building?'")
|
||||
print(" - 'Describe the emotions and mood in this image'")
|
||||
print(" - 'What text can you read in this image?'")
|
||||
print(" - 'Identify any safety hazards visible'")
|
||||
print(" - 'What products or brands are shown?'")
|
||||
|
||||
|
||||
print("\nDebug mode:")
|
||||
print(" # Enable debug logging")
|
||||
print(" export VISION_TOOLS_DEBUG=true")
|
||||
|
|
@ -461,30 +457,24 @@ VISION_ANALYZE_SCHEMA = {
|
|||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image_url": {
|
||||
"type": "string",
|
||||
"description": "Image URL (http/https) or local file path to analyze."
|
||||
},
|
||||
"image_url": {"type": "string", "description": "Image URL (http/https) or local file path to analyze."},
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "Your specific question or request about the image to resolve. The AI will automatically provide a complete image description AND answer your specific question."
|
||||
}
|
||||
"description": "Your specific question or request about the image to resolve. The AI will automatically provide a complete image description AND answer your specific question.",
|
||||
},
|
||||
},
|
||||
"required": ["image_url", "question"]
|
||||
}
|
||||
"required": ["image_url", "question"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _handle_vision_analyze(args: Dict[str, Any], **kw: Any) -> Awaitable[str]:
|
||||
def _handle_vision_analyze(args: dict[str, Any], **kw: Any) -> Awaitable[str]:
|
||||
image_url = args.get("image_url", "")
|
||||
question = args.get("question", "")
|
||||
full_prompt = (
|
||||
"Fully describe and explain everything about this image, then answer the "
|
||||
f"following question:\n\n{question}"
|
||||
f"Fully describe and explain everything about this image, then answer the following question:\n\n{question}"
|
||||
)
|
||||
model = (os.getenv("AUXILIARY_VISION_MODEL", "").strip()
|
||||
or DEFAULT_VISION_MODEL
|
||||
or "google/gemini-3-flash-preview")
|
||||
model = os.getenv("AUXILIARY_VISION_MODEL", "").strip() or DEFAULT_VISION_MODEL or "google/gemini-3-flash-preview"
|
||||
return vision_analyze_tool(image_url, full_prompt, model)
|
||||
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue