diff --git a/model_tools.py b/model_tools.py index 1924b2516..801255b79 100644 --- a/model_tools.py +++ b/model_tools.py @@ -26,7 +26,7 @@ import logging import threading from typing import Dict, Any, List, Optional, Tuple -from tools.registry import registry +from tools.registry import discover_builtin_tools, registry from toolsets import resolve_toolset, validate_toolset logger = logging.getLogger(__name__) @@ -129,45 +129,7 @@ def _run_async(coro): # Tool Discovery (importing each module triggers its registry.register calls) # ============================================================================= -def _discover_tools(): - """Import all tool modules to trigger their registry.register() calls. - - Wrapped in a function so import errors in optional tools (e.g., fal_client - not installed) don't prevent the rest from loading. - """ - _modules = [ - "tools.web_tools", - "tools.terminal_tool", - "tools.file_tools", - "tools.vision_tools", - "tools.mixture_of_agents_tool", - "tools.image_generation_tool", - "tools.skills_tool", - "tools.skill_manager_tool", - "tools.browser_tool", - "tools.cronjob_tools", - "tools.rl_training_tool", - "tools.tts_tool", - "tools.todo_tool", - "tools.memory_tool", - "tools.session_search_tool", - "tools.clarify_tool", - "tools.code_execution_tool", - "tools.delegate_tool", - "tools.process_registry", - "tools.send_message_tool", - # "tools.honcho_tools", # Removed — Honcho is now a memory provider plugin - "tools.homeassistant_tool", - ] - import importlib - for mod_name in _modules: - try: - importlib.import_module(mod_name) - except Exception as e: - logger.warning("Could not import tool module %s: %s", mod_name, e) - - -_discover_tools() +discover_builtin_tools() # MCP tool discovery (external MCP servers from config) try: diff --git a/tests/tools/test_registry.py b/tests/tools/test_registry.py index 6b2756886..85246bd76 100644 --- a/tests/tools/test_registry.py +++ b/tests/tools/test_registry.py @@ -2,8 +2,10 @@ import json import threading +from pathlib import Path +from unittest.mock import patch -from tools.registry import ToolRegistry +from tools.registry import ToolRegistry, discover_builtin_tools def _dummy_handler(args, **kwargs): @@ -286,6 +288,74 @@ class TestCheckFnExceptionHandling: assert any(u["name"] == "crashes" for u in unavailable) +class TestBuiltinDiscovery: + def test_matches_previous_manual_builtin_tool_set(self): + expected = { + "tools.browser_tool", + "tools.clarify_tool", + "tools.code_execution_tool", + "tools.cronjob_tools", + "tools.delegate_tool", + "tools.file_tools", + "tools.homeassistant_tool", + "tools.image_generation_tool", + "tools.memory_tool", + "tools.mixture_of_agents_tool", + "tools.process_registry", + "tools.rl_training_tool", + "tools.send_message_tool", + "tools.session_search_tool", + "tools.skill_manager_tool", + "tools.skills_tool", + "tools.terminal_tool", + "tools.todo_tool", + "tools.tts_tool", + "tools.vision_tools", + "tools.web_tools", + } + + with patch("tools.registry.importlib.import_module"): + imported = discover_builtin_tools(Path(__file__).resolve().parents[2] / "tools") + + assert set(imported) == expected + + def test_imports_only_self_registering_modules(self, tmp_path): + tools_dir = tmp_path / "tools" + tools_dir.mkdir() + (tools_dir / "__init__.py").write_text("", encoding="utf-8") + (tools_dir / "registry.py").write_text("", encoding="utf-8") + (tools_dir / "alpha.py").write_text( + "from tools.registry import registry\nregistry.register(name='alpha', toolset='x', schema={}, handler=lambda *_a, **_k: '{}')\n", + encoding="utf-8", + ) + (tools_dir / "beta.py").write_text("VALUE = 1\n", encoding="utf-8") + + with patch("tools.registry.importlib.import_module") as mock_import: + imported = discover_builtin_tools(tools_dir) + + assert imported == ["tools.alpha"] + mock_import.assert_called_once_with("tools.alpha") + + def test_skips_mcp_tool_even_if_it_registers(self, tmp_path): + tools_dir = tmp_path / "tools" + tools_dir.mkdir() + (tools_dir / "__init__.py").write_text("", encoding="utf-8") + (tools_dir / "mcp_tool.py").write_text( + "from tools.registry import registry\nregistry.register(name='mcp_alpha', toolset='mcp-test', schema={}, handler=lambda *_a, **_k: '{}')\n", + encoding="utf-8", + ) + (tools_dir / "alpha.py").write_text( + "from tools.registry import registry\nregistry.register(name='alpha', toolset='x', schema={}, handler=lambda *_a, **_k: '{}')\n", + encoding="utf-8", + ) + + with patch("tools.registry.importlib.import_module") as mock_import: + imported = discover_builtin_tools(tools_dir) + + assert imported == ["tools.alpha"] + mock_import.assert_called_once_with("tools.alpha") + + class TestEmojiMetadata: """Verify per-tool emoji registration and lookup.""" diff --git a/tools/registry.py b/tools/registry.py index ebda77807..53939047b 100644 --- a/tools/registry.py +++ b/tools/registry.py @@ -14,14 +14,59 @@ Import chain (circular-import safe): run_agent.py, cli.py, batch_runner.py, etc. """ +import ast +import importlib import json import logging import threading +from pathlib import Path from typing import Callable, Dict, List, Optional, Set logger = logging.getLogger(__name__) +def _module_registers_tools(module_path: Path) -> bool: + """Return True when the module contains a direct ``registry.register(...)`` call.""" + try: + source = module_path.read_text(encoding="utf-8") + tree = ast.parse(source, filename=str(module_path)) + except (OSError, SyntaxError): + return False + + for node in ast.walk(tree): + if not isinstance(node, ast.Call): + continue + func = node.func + if ( + isinstance(func, ast.Attribute) + and func.attr == "register" + and isinstance(func.value, ast.Name) + and func.value.id == "registry" + ): + return True + return False + + +def discover_builtin_tools(tools_dir: Optional[Path] = None) -> List[str]: + """Import built-in self-registering tool modules and return their module names.""" + tools_path = Path(tools_dir) if tools_dir is not None else Path(__file__).resolve().parent + module_names = [ + f"tools.{path.stem}" + for path in sorted(tools_path.glob("*.py")) + if path.name not in {"__init__.py", "registry.py", "mcp_tool.py"} + and _module_registers_tools(path) + ] + + imported: List[str] = [] + for mod_name in module_names: + try: + importlib.import_module(mod_name) + imported.append(mod_name) + except Exception as e: + logger.warning("Could not import tool module %s: %s", mod_name, e) + return imported + + class ToolEntry: """Metadata for a single registered tool."""