From 886d57d66ae00ed8b9c069f9d696b2d6aa836b4e Mon Sep 17 00:00:00 2001
From: woods-chen
Date: Thu, 2 May 2024 14:52:00 +0800
Subject: [PATCH] support local/package level and serializable python modules/macros. (#10078)

---
 core/dbt/artifacts/resources/types.py    |  1 +
 core/dbt/artifacts/resources/v1/macro.py |  2 +-
 core/dbt/context/macros.py               | 24 +++++--
 core/dbt/contracts/files.py              |  2 +
 core/dbt/contracts/graph/manifest.py     | 49 ++++++++-----
 core/dbt/contracts/graph/nodes.py        | 88 ++++++++++++++++++++++++
 core/dbt/parser/macros.py                | 70 ++++++++++++++++++-
 core/dbt/parser/manifest.py              | 22 +++++-
 core/dbt/parser/partial.py               |  1 +
 core/dbt/parser/read_files.py            |  5 ++
 tests/unit/test_node_types.py            |  1 +
 11 files changed, 236 insertions(+), 29 deletions(-)

diff --git a/core/dbt/artifacts/resources/types.py b/core/dbt/artifacts/resources/types.py
index c0ab5341e4c..e7ed05232b2 100644
--- a/core/dbt/artifacts/resources/types.py
+++ b/core/dbt/artifacts/resources/types.py
@@ -28,6 +28,7 @@ class NodeType(StrEnum):
     Documentation = "doc"
     Source = "source"
     Macro = "macro"
+    PythonModule = "python_module"
     Exposure = "exposure"
     Metric = "metric"
     Group = "group"
diff --git a/core/dbt/artifacts/resources/v1/macro.py b/core/dbt/artifacts/resources/v1/macro.py
index be02d529ee1..cee1c7e54a2 100644
--- a/core/dbt/artifacts/resources/v1/macro.py
+++ b/core/dbt/artifacts/resources/v1/macro.py
@@ -18,7 +18,7 @@ class MacroArgument(dbtClassMixin):
 @dataclass
 class Macro(BaseResource):
     macro_sql: str
-    resource_type: Literal[NodeType.Macro]
+    resource_type: Literal[NodeType.PythonModule, NodeType.Macro]
     depends_on: MacroDependsOn = field(default_factory=MacroDependsOn)
     description: str = ""
     meta: Dict[str, Any] = field(default_factory=dict)
diff --git a/core/dbt/context/macros.py b/core/dbt/context/macros.py
index c2442b1f4a8..f9697dd0227 100644
--- a/core/dbt/context/macros.py
+++ b/core/dbt/context/macros.py
@@ -1,13 +1,17 @@
+from types import ModuleType
 from typing import Any, Dict, Iterable, Union, Optional, List, Iterator, Mapping, Set

 from dbt.clients.jinja import MacroGenerator, MacroStack
-from dbt.contracts.graph.nodes import Macro
+from dbt.contracts.graph.nodes import (
+    Macro,
+    PythonModule,
+)
 from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
 from dbt.exceptions import DuplicateMacroNameError, PackageNotFoundForMacroError


-FlatNamespace = Dict[str, MacroGenerator]
-NamespaceMember = Union[FlatNamespace, MacroGenerator]
+FlatNamespace = Dict[str, Union[MacroGenerator, ModuleType]]
+NamespaceMember = Union[FlatNamespace, Union[MacroGenerator, ModuleType]]
 FullNamespace = Dict[str, NamespaceMember]


@@ -66,7 +70,9 @@ def __getitem__(self, key: str) -> NamespaceMember:
                 return dct[key]
         raise KeyError(key)

-    def get_from_package(self, package_name: Optional[str], name: str) -> Optional[MacroGenerator]:
+    def get_from_package(
+        self, package_name: Optional[str], name: str
+    ) -> Optional[Union[MacroGenerator, ModuleType]]:
         if package_name is None:
             return self.get(name)
         elif package_name == GLOBAL_PROJECT_NAME:
@@ -112,7 +118,7 @@ def _add_macro_to(
         self,
         hierarchy: Dict[str, FlatNamespace],
         macro: Macro,
-        macro_func: MacroGenerator,
+        macro_func: Union[MacroGenerator, ModuleType],
     ):
         if macro.package_name in hierarchy:
             namespace = hierarchy[macro.package_name]
@@ -124,13 +130,17 @@ def _add_macro_to(
             raise DuplicateMacroNameError(macro_func.macro, macro, macro.package_name)
         hierarchy[macro.package_name][macro.name] = macro_func

-    def add_macro(self, macro: Macro, ctx: Dict[str, Any]) -> None:
+    def add_macro(self, macro: Union[PythonModule, Macro], ctx: Dict[str, Any]) -> None:
         macro_name: str = macro.name

         # MacroGenerator is in clients/jinja.py
         # a MacroGenerator object is a callable object that will
         # execute the MacroGenerator.__call__ function
-        macro_func: MacroGenerator = MacroGenerator(macro, ctx, self.node, self.thread_ctx)
+        macro_func: Union[MacroGenerator, ModuleType] = (
+            macro.module
+            if isinstance(macro, PythonModule)
+            else MacroGenerator(macro, ctx, self.node, self.thread_ctx)
+        )

         # internal macros (from plugins) will be processed separately from
         # project macros, so store them in a different place
diff --git a/core/dbt/contracts/files.py b/core/dbt/contracts/files.py
index 714782161cc..6588badb844 100644
--- a/core/dbt/contracts/files.py
+++ b/core/dbt/contracts/files.py
@@ -13,6 +13,7 @@

 class ParseFileType(StrEnum):
     Macro = "macro"
+    PythonModule = "python_module"
     Model = "model"
     Snapshot = "snapshot"
     Analysis = "analysis"
@@ -27,6 +28,7 @@ class ParseFileType(StrEnum):

 parse_file_type_to_parser = {
     ParseFileType.Macro: "MacroParser",
+    ParseFileType.PythonModule: "PythonModuleParser",
     ParseFileType.Model: "ModelParser",
     ParseFileType.Snapshot: "SnapshotParser",
     ParseFileType.Analysis: "AnalysisParser",
diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py
index b7b8142d72e..bb390ca1edb 100644
--- a/core/dbt/contracts/graph/manifest.py
+++ b/core/dbt/contracts/graph/manifest.py
@@ -33,6 +33,7 @@
     GraphMemberNode,
     Group,
     Macro,
+    PythonModule,
     ManifestNode,
     Metric,
     ModelNode,
@@ -489,7 +490,11 @@ def build_node_edges(nodes: List[ManifestNode]):
 # Build a map of children of macros and generic tests
 def build_macro_edges(nodes: List[Any]):
     forward_edges: Dict[str, List[str]] = {
-        n.unique_id: [] for n in nodes if n.unique_id.startswith("macro") or n.depends_on_macros
+        n.unique_id: []
+        for n in nodes
+        if n.unique_id.startswith("macro")
+        or n.unique_id.startswith("python_module")
+        or n.depends_on_macros
     }
     for node in nodes:
         for unique_id in node.depends_on_macros:
@@ -511,7 +516,7 @@ class Locality(enum.IntEnum):
 @dataclass
 class MacroCandidate:
     locality: Locality
-    macro: Macro
+    macro: Union[PythonModule, Macro]

     def __eq__(self, other: object) -> bool:
         if not isinstance(other, MacroCandidate):
@@ -591,12 +596,14 @@ def last_candidate(

         return None

-    def last(self) -> Optional[Macro]:
+    def last(self) -> Optional[Union[PythonModule, Macro]]:
         last_candidate = self.last_candidate()
         return last_candidate.macro if last_candidate is not None else None


-def _get_locality(macro: Macro, root_project_name: str, internal_packages: Set[str]) -> Locality:
+def _get_locality(
+    macro: Union[PythonModule, Macro], root_project_name: str, internal_packages: Set[str]
+) -> Locality:
     if macro.package_name == root_project_name:
         return Locality.Root
     elif macro.package_name in internal_packages:
@@ -657,7 +664,7 @@ def __init__(self):

     def find_macro_by_name(
         self, name: str, root_project_name: str, package: Optional[str]
-    ) -> Optional[Macro]:
+    ) -> Optional[Union[PythonModule, Macro]]:
         """Find a macro in the graph by its name and package name, or None for
         any package. The root project name is used to determine priority:
          - locally defined macros come first
@@ -680,7 +687,7 @@ def filter(candidate: MacroCandidate) -> bool:

     def find_generate_macro_by_name(
         self, component: str, root_project_name: str, imported_package: Optional[str] = None
-    ) -> Optional[Macro]:
+    ) -> Optional[Union[PythonModule, Macro]]:
         """
         The default `generate_X_name` macros are similar to regular ones, but only
         includes imported packages when searching for a package.
@@ -738,7 +745,7 @@ def _find_macros_by_name(

         return candidates

-    def get_macros_by_name(self) -> Dict[str, List[Macro]]:
+    def get_macros_by_name(self) -> Dict[str, List[Union[PythonModule, Macro]]]:
         if self._macros_by_name is None:
             # The by-name mapping doesn't exist yet (perhaps because the manifest
             # was deserialized), so we build it.
@@ -747,11 +754,13 @@ def get_macros_by_name(self) -> Dict[str, List[Macro]]:
         return self._macros_by_name

     @staticmethod
-    def _build_macros_by_name(macros: Mapping[str, Macro]) -> Dict[str, List[Macro]]:
+    def _build_macros_by_name(
+        macros: Mapping[str, Union[PythonModule, Macro]]
+    ) -> Dict[str, List[Union[PythonModule, Macro]]]:
         # Convert a macro dictionary keyed on unique id to a flattened version
         # keyed on macro name for faster lookup by name. Since macro names are
         # not necessarily unique, the dict value is a list.
-        macros_by_name: Dict[str, List[Macro]] = {}
+        macros_by_name: Dict[str, List[Union[PythonModule, Macro]]] = {}
         for macro in macros.values():
             if macro.name not in macros_by_name:
                 macros_by_name[macro.name] = []
@@ -760,7 +769,7 @@ def _build_macros_by_name(macros: Mapping[str, Macro]) -> Dict[str, List[Macro]]

         return macros_by_name

-    def get_macros_by_package(self) -> Dict[str, Dict[str, Macro]]:
+    def get_macros_by_package(self) -> Dict[str, Dict[str, Union[PythonModule, Macro]]]:
         if self._macros_by_package is None:
             # The by-package mapping doesn't exist yet (perhaps because the manifest
             # was deserialized), so we build it.
@@ -769,10 +778,12 @@ def get_macros_by_package(self) -> Dict[str, Dict[str, Macro]]:
         return self._macros_by_package

     @staticmethod
-    def _build_macros_by_package(macros: Mapping[str, Macro]) -> Dict[str, Dict[str, Macro]]:
+    def _build_macros_by_package(
+        macros: Mapping[str, Union[PythonModule, Macro]]
+    ) -> Dict[str, Dict[str, Union[PythonModule, Macro]]]:
         # Convert a macro dictionary keyed on unique id to a flattened version
         # keyed on package name for faster lookup by name.
-        macros_by_package: Dict[str, Dict[str, Macro]] = {}
+        macros_by_package: Dict[str, Dict[str, Union[PythonModule, Macro]]] = {}
         for macro in macros.values():
             if macro.package_name not in macros_by_package:
                 macros_by_package[macro.package_name] = {}
@@ -810,7 +821,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
     # args tuple in the right position.
     nodes: MutableMapping[str, ManifestNode] = field(default_factory=dict)
     sources: MutableMapping[str, SourceDefinition] = field(default_factory=dict)
-    macros: MutableMapping[str, Macro] = field(default_factory=dict)
+    macros: MutableMapping[str, Union[PythonModule, Macro]] = field(default_factory=dict)
     docs: MutableMapping[str, Documentation] = field(default_factory=dict)
     exposures: MutableMapping[str, Exposure] = field(default_factory=dict)
     metrics: MutableMapping[str, Metric] = field(default_factory=dict)
@@ -860,11 +871,11 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
         default_factory=get_mp_context().Lock,
         metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
     )
-    _macros_by_name: Optional[Dict[str, List[Macro]]] = field(
+    _macros_by_name: Optional[Dict[str, List[Union[PythonModule, Macro]]]] = field(
         default=None,
         metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
     )
-    _macros_by_package: Optional[Dict[str, Dict[str, Macro]]] = field(
+    _macros_by_package: Optional[Dict[str, Dict[str, Union[PythonModule, Macro]]]] = field(
         default=None,
         metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
     )
@@ -937,7 +948,7 @@ def _materialization_candidates_for(

     def find_materialization_macro_by_name(
         self, project_name: str, materialization_name: str, adapter_type: str
-    ) -> Optional[Macro]:
+    ) -> Optional[Union[PythonModule, Macro]]:
         candidates: CandidateList = CandidateList(
             chain.from_iterable(
                 self._materialization_candidates_for(
@@ -1511,7 +1522,7 @@ def merge_from_artifact(
         fire_event(MergedFromState(num_merged=len(merged), sample=sample))

     # Methods that were formerly in ParseResult
-    def add_macro(self, source_file: SourceFile, macro: Macro):
+    def add_macro(self, source_file: SourceFile, macro: Union[PythonModule, Macro]):
         if macro.unique_id in self.macros:
             # detect that the macro exists and emit an error
             raise DuplicateMacroInPackageError(macro=macro, macro_mapping=self.macros)
@@ -1696,8 +1707,8 @@ def __init__(self, macros) -> None:
         # This is returned by the 'graph' context property
         # in the ProviderContext class.
         self.flat_graph: Dict[str, Any] = {}
-        self._macros_by_name: Optional[Dict[str, List[Macro]]] = None
-        self._macros_by_package: Optional[Dict[str, Dict[str, Macro]]] = None
+        self._macros_by_name: Optional[Dict[str, List[Union[PythonModule, Macro]]]] = None
+        self._macros_by_package: Optional[Dict[str, Dict[str, Union[PythonModule, Macro]]]] = None


 AnyManifest = Union[Manifest, MacroManifest]
diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py
index e1f409ff1de..40af1263a15 100644
--- a/core/dbt/contracts/graph/nodes.py
+++ b/core/dbt/contracts/graph/nodes.py
@@ -1,7 +1,11 @@
+import importlib.util
+from types import ModuleType
 import os
 from datetime import datetime
 from dataclasses import dataclass, field
+from functools import cached_property
 import hashlib
+from zipfile import ZipFile

 from mashumaro.types import SerializableType
 from typing import (
@@ -1042,6 +1046,90 @@ def depends_on_macros(self):
         return self.depends_on.macros


+# ====================================
+# PythonModule
+# ====================================
+
+
+def load_python_module(name: str, path: str, files: List[str]) -> ModuleType:
+    files.append(path)
+    spec: Any = importlib.util.spec_from_file_location(name, path)
+    module: ModuleType = importlib.util.module_from_spec(spec)
+    # # relative importing in the python macros is not supported if modules are not
+    # # added to sys.modules.
+    # # this may increase the probability of module name conflicts.
+    # sys.modules[self.name] = module
+    spec.loader.exec_module(module)
+    f_init = "__init__.py"
+    if os.path.basename(path) == f_init:
+        dir_path = os.path.dirname(path)
+        for sub_obj in os.listdir(dir_path):
+            sub_obj_path = os.path.join(dir_path, sub_obj)
+            sub_obj_name, sub_obj_ext = os.path.splitext(sub_obj)
+            if (
+                os.path.isfile(sub_obj_path)
+                and sub_obj_ext == ".py"
+                and sub_obj != f_init
+                and sub_obj_name.isidentifier()
+                and sub_obj_name not in dir(module)
+            ):
+                setattr(
+                    module,
+                    sub_obj_name,
+                    load_python_module(name=sub_obj_name, path=sub_obj_path, files=files),
+                )
+            elif (
+                os.path.isdir(sub_obj_path)
+                and os.path.isfile(os.path.join(sub_obj_path, f_init))
+                and sub_obj_name.isidentifier()
+                and sub_obj_name not in dir(module)
+            ):
+                setattr(
+                    module,
+                    sub_obj_name,
+                    load_python_module(
+                        name=sub_obj_name, path=os.path.join(sub_obj_path, f_init), files=files
+                    ),
+                )
+    return module
+
+
+@dataclass
+class PythonModule(Macro):
+    absolute_path: str = ""
+    relative_path: str = ""
+    resource_type: Literal[NodeType.PythonModule] = field(
+        metadata={"restrict": Literal[NodeType.PythonModule]}
+    )
+    files: List[str] = field(default_factory=list)
+
+    def __post_init__(self):
+        _ = self.module
+
+    @cached_property
+    def module(self) -> ModuleType:
+        module = load_python_module(name=self.name, path=self.absolute_path, files=self.files)
+        # # relative importing in the python macros is not supported if modules are not
+        # # added to sys.modules.
+        # # this may increase the probability of module name conflicts.
+        # sys.modules[self.name] = module
+        return module
+
+    def add_to_zipfile(
+        self,
+        zipfile: str,
+        mode: Literal["r", "w", "x", "a"] = "a",
+        is_in_root_package: bool = True,
+    ):
+        len_root_path = len(self.absolute_path) - len(self.relative_path)
+        with ZipFile(zipfile, mode=mode) as f:
+            for file in self.files:
+                arcname = file[len_root_path:]
+                if not is_in_root_package:
+                    arcname = os.path.join(self.package_name, arcname)
+                f.write(filename=file, arcname=arcname)
+
+
 # ====================================
 # Documentation node
 # ====================================
diff --git a/core/dbt/parser/macros.py b/core/dbt/parser/macros.py
index 23a9bf53060..2d74db3814c 100644
--- a/core/dbt/parser/macros.py
+++ b/core/dbt/parser/macros.py
@@ -1,11 +1,15 @@
 from typing import Iterable, List
+import os

 import jinja2

 from dbt_common.clients import jinja
 from dbt.clients.jinja import get_supported_languages
 from dbt.contracts.graph.unparsed import UnparsedMacro
-from dbt.contracts.graph.nodes import Macro
+from dbt.contracts.graph.nodes import (
+    Macro,
+    PythonModule,
+)
 from dbt.contracts.files import FilePath, SourceFile
 from dbt.exceptions import ParsingError
 from dbt.node_types import NodeType
@@ -117,3 +121,67 @@ def parse_file(self, block: FileBlock):

         for node in self.parse_unparsed_macros(base_node):
             self.manifest.add_macro(block.file, node)
+
+
+class PythonModuleParser(BaseParser[Macro]):
+    _module = None
+
+    def get_paths(self) -> List[FilePath]:
+        return filesystem_search(
+            project=self.project, relative_dirs=self.project.macro_paths, extension=".py"
+        )
+
+    @property
+    def resource_type(self) -> NodeType:
+        return NodeType.PythonModule
+
+    @classmethod
+    def get_compiled_path(cls, block: FileBlock):
+        return block.path.relative_path
+
+    @classmethod
+    def get_module_name(cls, block: FileBlock):
+        relative_path = block.path.relative_path
+        paths = os.path.normpath(relative_path).split(os.sep)
+        file_name = paths[-1]
+        if file_name == "__init__.py":
+            paths.pop()
+        else:
+            file_name = ".".join(file_name.split(".")[:-1])
+            paths[-1] = file_name
+        if not all(i.isidentifier() for i in paths):
+            raise ParsingError(f'invalid python module name: "{relative_path}"')
+        name = ".".join(paths)
+        return name
+
+    def parse_file(self, block: FileBlock):
+        assert isinstance(block.file, SourceFile)
+        source_file: SourceFile = block.file
+        assert isinstance(source_file.contents, str)
+        name: str = self.get_module_name(block=block)
+        if not name:
+            return
+        # only parse top level python modules, the sub-modules will be loaded by
+        # their parent modules.
+        if len(name.split(".")) > 1:
+            return
+        # root_unique_id: str = self.generate_unique_id(name.split('.')[0])
+        # if root_unique_id in self.manifest.macros:
+        #     return
+        original_file_path: str = source_file.path.original_file_path
+        absolute_path: str = block.path.absolute_path
+        relative_path: str = block.path.relative_path
+        unique_id: str = self.generate_unique_id(name)
+        node = PythonModule(
+            path=original_file_path,
+            macro_sql=source_file.contents,
+            original_file_path=original_file_path,
+            package_name=self.project.project_name,
+            resource_type=NodeType.PythonModule,
+            name=name,
+            unique_id=unique_id,
+            absolute_path=absolute_path,
+            relative_path=relative_path,
+            files=[],
+        )
+        self.manifest.add_macro(block.file, node)
diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py
index 1d39fd21f6e..7e4c78dd2cd 100644
--- a/core/dbt/parser/manifest.py
+++ b/core/dbt/parser/manifest.py
@@ -126,7 +126,10 @@
 from dbt.parser.docs import DocumentationParser
 from dbt.parser.fixtures import FixtureParser
 from dbt.parser.hooks import HookParser
-from dbt.parser.macros import MacroParser
+from dbt.parser.macros import (
+    MacroParser,
+    PythonModuleParser,
+)
 from dbt.parser.models import ModelParser
 from dbt.parser.schemas import SchemaParser
 from dbt.parser.search import FileBlock
@@ -677,6 +680,13 @@ def load_and_parse_macros(self, project_parser_files):
                     parser.parse_file(block)
                     # increment parsed path count for performance tracking
                     self._perf_info.parsed_path_count += 1
+            if "PythonModuleParser" in parser_files:
+                parser = PythonModuleParser(project, self.manifest)
+                for file_id in parser_files["PythonModuleParser"]:
+                    block = FileBlock(self.manifest.files[file_id])
+                    parser.parse_file(block)
+                    # increment parsed path count for performance tracking
+                    self._perf_info.parsed_path_count += 1
             # generic tests hisotrically lived in the macros directoy but can now be nested
             # in a /generic directory under /tests so we want to process them here as well
             if "GenericTestParser" in parser_files:
@@ -1071,6 +1081,16 @@ def create_macro_manifest(self):
                 # This does not add the file to the manifest.files,
                 # but that shouldn't be necessary here.
                 macro_parser.parse_file(block)
+            # what is the manifest passed in actually used for?
+            python_module_parser = PythonModuleParser(project, self.manifest)
+            for path in python_module_parser.get_paths():
+                source_file = load_source_file(
+                    path, ParseFileType.PythonModule, project.project_name, {}
+                )
+                block = FileBlock(source_file)
+                # This does not add the file to the manifest.files,
+                # but that shouldn't be necessary here.
+                python_module_parser.parse_file(block)

         macro_manifest = MacroManifest(self.manifest.macros)
         return macro_manifest
diff --git a/core/dbt/parser/partial.py b/core/dbt/parser/partial.py
index f9c558be6ba..96b468566bd 100644
--- a/core/dbt/parser/partial.py
+++ b/core/dbt/parser/partial.py
@@ -30,6 +30,7 @@

 mg_files = (
     ParseFileType.Macro,
+    ParseFileType.PythonModule,
     ParseFileType.GenericTest,
 )

diff --git a/core/dbt/parser/read_files.py b/core/dbt/parser/read_files.py
index 314a2a0fdd1..c897c40c28f 100644
--- a/core/dbt/parser/read_files.py
+++ b/core/dbt/parser/read_files.py
@@ -389,6 +389,11 @@ def get_file_types_for_project(project):
             "extensions": [".sql"],
             "parser": "MacroParser",
         },
+        ParseFileType.PythonModule: {
+            "paths": project.macro_paths,
+            "extensions": [".py"],
+            "parser": "PythonModuleParser",
+        },
         ParseFileType.Model: {
             "paths": project.model_paths,
             "extensions": [".sql", ".py"],
diff --git a/tests/unit/test_node_types.py b/tests/unit/test_node_types.py
index 9611429a934..c103b2df0b0 100644
--- a/tests/unit/test_node_types.py
+++ b/tests/unit/test_node_types.py
@@ -13,6 +13,7 @@
     NodeType.Documentation: "docs",
     NodeType.Source: "sources",
     NodeType.Macro: "macros",
+    NodeType.PythonModule: "python_modules",
     NodeType.Exposure: "exposures",
     NodeType.Metric: "metrics",
     NodeType.Group: "groups",