fix: tighten AST check to module-level only

The original tree-wide ast.walk() would match registry.register() calls
inside functions too. Restrict to top-level ast.Expr statements so helper
modules that call registry.register() inside a function are never picked
up as tool modules.
This commit is contained in:
Teknium 2026-04-14 20:51:55 -07:00 committed by Teknium
parent 4b2a1a4337
commit fc6cb5b970

View file

@ -25,26 +25,32 @@ from typing import Callable, Dict, List, Optional, Set
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _is_registry_register_call(node: ast.AST) -> bool:
"""Return True when *node* is a ``registry.register(...)`` call expression."""
if not isinstance(node, ast.Expr) or not isinstance(node.value, ast.Call):
return False
func = node.value.func
return (
isinstance(func, ast.Attribute)
and func.attr == "register"
and isinstance(func.value, ast.Name)
and func.value.id == "registry"
)
def _module_registers_tools(module_path: Path) -> bool: def _module_registers_tools(module_path: Path) -> bool:
"""Return True when the module contains a direct ``registry.register(...)`` call.""" """Return True when the module contains a top-level ``registry.register(...)`` call.
Only inspects module-body statements so that helper modules which happen
to call ``registry.register()`` inside a function are not picked up.
"""
try: try:
source = module_path.read_text(encoding="utf-8") source = module_path.read_text(encoding="utf-8")
tree = ast.parse(source, filename=str(module_path)) tree = ast.parse(source, filename=str(module_path))
except (OSError, SyntaxError): except (OSError, SyntaxError):
return False return False
for node in ast.walk(tree): return any(_is_registry_register_call(stmt) for stmt in tree.body)
if not isinstance(node, ast.Call):
continue
func = node.func
if (
isinstance(func, ast.Attribute)
and func.attr == "register"
and isinstance(func.value, ast.Name)
and func.value.id == "registry"
):
return True
return False
def discover_builtin_tools(tools_dir: Optional[Path] = None) -> List[str]: def discover_builtin_tools(tools_dir: Optional[Path] = None) -> List[str]: