From fc6cb5b970f006dba448941ce5b3888fc36662fb Mon Sep 17 00:00:00 2001 From: Teknium Date: Tue, 14 Apr 2026 20:51:55 -0700 Subject: [PATCH] fix: tighten AST check to module-level only The original tree-wide ast.walk() would match registry.register() calls inside functions too. Restrict to top-level ast.Expr statements so helper modules that call registry.register() inside a function are never picked up as tool modules. --- tools/registry.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/tools/registry.py b/tools/registry.py index 53939047b..e6d554e2b 100644 --- a/tools/registry.py +++ b/tools/registry.py @@ -25,26 +25,32 @@ from typing import Callable, Dict, List, Optional, Set logger = logging.getLogger(__name__) +def _is_registry_register_call(node: ast.AST) -> bool: + """Return True when *node* is a ``registry.register(...)`` call expression.""" + if not isinstance(node, ast.Expr) or not isinstance(node.value, ast.Call): + return False + func = node.value.func + return ( + isinstance(func, ast.Attribute) + and func.attr == "register" + and isinstance(func.value, ast.Name) + and func.value.id == "registry" + ) + + def _module_registers_tools(module_path: Path) -> bool: - """Return True when the module contains a direct ``registry.register(...)`` call.""" + """Return True when the module contains a top-level ``registry.register(...)`` call. + + Only inspects module-body statements so that helper modules which happen + to call ``registry.register()`` inside a function are not picked up. + """ try: source = module_path.read_text(encoding="utf-8") tree = ast.parse(source, filename=str(module_path)) except (OSError, SyntaxError): return False - for node in ast.walk(tree): - if not isinstance(node, ast.Call): - continue - func = node.func - if ( - isinstance(func, ast.Attribute) - and func.attr == "register" - and isinstance(func.value, ast.Name) - and func.value.id == "registry" - ): - return True - return False + return any(_is_registry_register_call(stmt) for stmt in tree.body) def discover_builtin_tools(tools_dir: Optional[Path] = None) -> List[str]: