Skip to content

Commit

Permalink
support local/package level and serializable python modules/macros. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
woods-chen committed May 2, 2024
1 parent 2e3c6fe commit 886d57d
Show file tree
Hide file tree
Showing 11 changed files with 236 additions and 29 deletions.
1 change: 1 addition & 0 deletions core/dbt/artifacts/resources/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class NodeType(StrEnum):
Documentation = "doc"
Source = "source"
Macro = "macro"
PythonModule = "python_module"
Exposure = "exposure"
Metric = "metric"
Group = "group"
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/artifacts/resources/v1/macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class MacroArgument(dbtClassMixin):
@dataclass
class Macro(BaseResource):
macro_sql: str
resource_type: Literal[NodeType.Macro]
resource_type: Literal[NodeType.PythonModule, NodeType.Macro]
depends_on: MacroDependsOn = field(default_factory=MacroDependsOn)
description: str = ""
meta: Dict[str, Any] = field(default_factory=dict)
Expand Down
24 changes: 17 additions & 7 deletions core/dbt/context/macros.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from types import ModuleType
from typing import Any, Dict, Iterable, Union, Optional, List, Iterator, Mapping, Set

from dbt.clients.jinja import MacroGenerator, MacroStack
from dbt.contracts.graph.nodes import Macro
from dbt.contracts.graph.nodes import (
Macro,
PythonModule,
)
from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
from dbt.exceptions import DuplicateMacroNameError, PackageNotFoundForMacroError


FlatNamespace = Dict[str, MacroGenerator]
NamespaceMember = Union[FlatNamespace, MacroGenerator]
FlatNamespace = Dict[str, Union[MacroGenerator, ModuleType]]
NamespaceMember = Union[FlatNamespace, Union[MacroGenerator, ModuleType]]
FullNamespace = Dict[str, NamespaceMember]


Expand Down Expand Up @@ -66,7 +70,9 @@ def __getitem__(self, key: str) -> NamespaceMember:
return dct[key]
raise KeyError(key)

def get_from_package(self, package_name: Optional[str], name: str) -> Optional[MacroGenerator]:
def get_from_package(
self, package_name: Optional[str], name: str
) -> Optional[Union[MacroGenerator, ModuleType]]:
if package_name is None:
return self.get(name)
elif package_name == GLOBAL_PROJECT_NAME:
Expand Down Expand Up @@ -112,7 +118,7 @@ def _add_macro_to(
self,
hierarchy: Dict[str, FlatNamespace],
macro: Macro,
macro_func: MacroGenerator,
macro_func: Union[MacroGenerator, ModuleType],
):
if macro.package_name in hierarchy:
namespace = hierarchy[macro.package_name]
Expand All @@ -124,13 +130,17 @@ def _add_macro_to(
raise DuplicateMacroNameError(macro_func.macro, macro, macro.package_name)
hierarchy[macro.package_name][macro.name] = macro_func

def add_macro(self, macro: Macro, ctx: Dict[str, Any]) -> None:
def add_macro(self, macro: Union[PythonModule, Macro], ctx: Dict[str, Any]) -> None:
macro_name: str = macro.name

# MacroGenerator is in clients/jinja.py
# a MacroGenerator object is a callable object that will
# execute the MacroGenerator.__call__ function
macro_func: MacroGenerator = MacroGenerator(macro, ctx, self.node, self.thread_ctx)
macro_func: Union[MacroGenerator, ModuleType] = (
macro.module
if isinstance(macro, PythonModule)
else MacroGenerator(macro, ctx, self.node, self.thread_ctx)
)

# internal macros (from plugins) will be processed separately from
# project macros, so store them in a different place
Expand Down
2 changes: 2 additions & 0 deletions core/dbt/contracts/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

class ParseFileType(StrEnum):
Macro = "macro"
PythonModule = "python_module"
Model = "model"
Snapshot = "snapshot"
Analysis = "analysis"
Expand All @@ -27,6 +28,7 @@ class ParseFileType(StrEnum):

parse_file_type_to_parser = {
ParseFileType.Macro: "MacroParser",
ParseFileType.PythonModule: "PythonModuleParser",
ParseFileType.Model: "ModelParser",
ParseFileType.Snapshot: "SnapshotParser",
ParseFileType.Analysis: "AnalysisParser",
Expand Down
49 changes: 30 additions & 19 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
GraphMemberNode,
Group,
Macro,
PythonModule,
ManifestNode,
Metric,
ModelNode,
Expand Down Expand Up @@ -489,7 +490,11 @@ def build_node_edges(nodes: List[ManifestNode]):
# Build a map of children of macros and generic tests
def build_macro_edges(nodes: List[Any]):
forward_edges: Dict[str, List[str]] = {
n.unique_id: [] for n in nodes if n.unique_id.startswith("macro") or n.depends_on_macros
n.unique_id: []
for n in nodes
if n.unique_id.startswith("macro")
or n.unique_id.startswith("python_module")
or n.depends_on_macros
}
for node in nodes:
for unique_id in node.depends_on_macros:
Expand All @@ -511,7 +516,7 @@ class Locality(enum.IntEnum):
@dataclass
class MacroCandidate:
locality: Locality
macro: Macro
macro: Union[PythonModule, Macro]

def __eq__(self, other: object) -> bool:
if not isinstance(other, MacroCandidate):
Expand Down Expand Up @@ -591,12 +596,14 @@ def last_candidate(

return None

def last(self) -> Optional[Macro]:
def last(self) -> Optional[Union[PythonModule, Macro]]:
last_candidate = self.last_candidate()
return last_candidate.macro if last_candidate is not None else None


def _get_locality(macro: Macro, root_project_name: str, internal_packages: Set[str]) -> Locality:
def _get_locality(
macro: Union[PythonModule, Macro], root_project_name: str, internal_packages: Set[str]
) -> Locality:
if macro.package_name == root_project_name:
return Locality.Root
elif macro.package_name in internal_packages:
Expand Down Expand Up @@ -657,7 +664,7 @@ def __init__(self):

def find_macro_by_name(
self, name: str, root_project_name: str, package: Optional[str]
) -> Optional[Macro]:
) -> Optional[Union[PythonModule, Macro]]:
"""Find a macro in the graph by its name and package name, or None for
any package. The root project name is used to determine priority:
- locally defined macros come first
Expand All @@ -680,7 +687,7 @@ def filter(candidate: MacroCandidate) -> bool:

def find_generate_macro_by_name(
self, component: str, root_project_name: str, imported_package: Optional[str] = None
) -> Optional[Macro]:
) -> Optional[Union[PythonModule, Macro]]:
"""
The default `generate_X_name` macros are similar to regular ones, but only
includes imported packages when searching for a package.
Expand Down Expand Up @@ -738,7 +745,7 @@ def _find_macros_by_name(

return candidates

def get_macros_by_name(self) -> Dict[str, List[Macro]]:
def get_macros_by_name(self) -> Dict[str, List[Union[PythonModule, Macro]]]:
if self._macros_by_name is None:
# The by-name mapping doesn't exist yet (perhaps because the manifest
# was deserialized), so we build it.
Expand All @@ -747,11 +754,13 @@ def get_macros_by_name(self) -> Dict[str, List[Macro]]:
return self._macros_by_name

@staticmethod
def _build_macros_by_name(macros: Mapping[str, Macro]) -> Dict[str, List[Macro]]:
def _build_macros_by_name(
macros: Mapping[str, Union[PythonModule, Macro]]
) -> Dict[str, List[Union[PythonModule, Macro]]]:
# Convert a macro dictionary keyed on unique id to a flattened version
# keyed on macro name for faster lookup by name. Since macro names are
# not necessarily unique, the dict value is a list.
macros_by_name: Dict[str, List[Macro]] = {}
macros_by_name: Dict[str, List[Union[PythonModule, Macro]]] = {}
for macro in macros.values():
if macro.name not in macros_by_name:
macros_by_name[macro.name] = []
Expand All @@ -760,7 +769,7 @@ def _build_macros_by_name(macros: Mapping[str, Macro]) -> Dict[str, List[Macro]]

return macros_by_name

def get_macros_by_package(self) -> Dict[str, Dict[str, Macro]]:
def get_macros_by_package(self) -> Dict[str, Dict[str, Union[PythonModule, Macro]]]:
if self._macros_by_package is None:
# The by-package mapping doesn't exist yet (perhaps because the manifest
# was deserialized), so we build it.
Expand All @@ -769,10 +778,12 @@ def get_macros_by_package(self) -> Dict[str, Dict[str, Macro]]:
return self._macros_by_package

@staticmethod
def _build_macros_by_package(macros: Mapping[str, Macro]) -> Dict[str, Dict[str, Macro]]:
def _build_macros_by_package(
macros: Mapping[str, Union[PythonModule, Macro]]
) -> Dict[str, Dict[str, Union[PythonModule, Macro]]]:
# Convert a macro dictionary keyed on unique id to a flattened version
# keyed on package name for faster lookup by name.
macros_by_package: Dict[str, Dict[str, Macro]] = {}
macros_by_package: Dict[str, Dict[str, Union[PythonModule, Macro]]] = {}
for macro in macros.values():
if macro.package_name not in macros_by_package:
macros_by_package[macro.package_name] = {}
Expand Down Expand Up @@ -810,7 +821,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
# args tuple in the right position.
nodes: MutableMapping[str, ManifestNode] = field(default_factory=dict)
sources: MutableMapping[str, SourceDefinition] = field(default_factory=dict)
macros: MutableMapping[str, Macro] = field(default_factory=dict)
macros: MutableMapping[str, Union[PythonModule, Macro]] = field(default_factory=dict)
docs: MutableMapping[str, Documentation] = field(default_factory=dict)
exposures: MutableMapping[str, Exposure] = field(default_factory=dict)
metrics: MutableMapping[str, Metric] = field(default_factory=dict)
Expand Down Expand Up @@ -860,11 +871,11 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
default_factory=get_mp_context().Lock,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)
_macros_by_name: Optional[Dict[str, List[Macro]]] = field(
_macros_by_name: Optional[Dict[str, List[Union[PythonModule, Macro]]]] = field(
default=None,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)
_macros_by_package: Optional[Dict[str, Dict[str, Macro]]] = field(
_macros_by_package: Optional[Dict[str, Dict[str, Union[PythonModule, Macro]]]] = field(
default=None,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
)
Expand Down Expand Up @@ -937,7 +948,7 @@ def _materialization_candidates_for(

def find_materialization_macro_by_name(
self, project_name: str, materialization_name: str, adapter_type: str
) -> Optional[Macro]:
) -> Optional[Union[PythonModule, Macro]]:
candidates: CandidateList = CandidateList(
chain.from_iterable(
self._materialization_candidates_for(
Expand Down Expand Up @@ -1511,7 +1522,7 @@ def merge_from_artifact(
fire_event(MergedFromState(num_merged=len(merged), sample=sample))

# Methods that were formerly in ParseResult
def add_macro(self, source_file: SourceFile, macro: Macro):
def add_macro(self, source_file: SourceFile, macro: Union[PythonModule, Macro]):
if macro.unique_id in self.macros:
# detect that the macro exists and emit an error
raise DuplicateMacroInPackageError(macro=macro, macro_mapping=self.macros)
Expand Down Expand Up @@ -1696,8 +1707,8 @@ def __init__(self, macros) -> None:
# This is returned by the 'graph' context property
# in the ProviderContext class.
self.flat_graph: Dict[str, Any] = {}
self._macros_by_name: Optional[Dict[str, List[Macro]]] = None
self._macros_by_package: Optional[Dict[str, Dict[str, Macro]]] = None
self._macros_by_name: Optional[Dict[str, List[Union[PythonModule, Macro]]]] = None
self._macros_by_package: Optional[Dict[str, Dict[str, Union[PythonModule, Macro]]]] = None


AnyManifest = Union[Manifest, MacroManifest]
Expand Down
88 changes: 88 additions & 0 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import importlib.util
from types import ModuleType
import os
from datetime import datetime
from dataclasses import dataclass, field
from functools import cached_property
import hashlib
from zipfile import ZipFile

from mashumaro.types import SerializableType
from typing import (
Expand Down Expand Up @@ -1042,6 +1046,90 @@ def depends_on_macros(self):
return self.depends_on.macros


# ====================================
# PythonModule
# ====================================


def load_python_module(name: str, path: str, files: List[str]) -> ModuleType:
files.append(path)
spec: Any = importlib.util.spec_from_file_location(name, path)
module: ModuleType = importlib.util.module_from_spec(spec)
# # relative importing in the python macros is not supported if modules are not
# # added to sys.modules.
# # this may increase the probability of module name conflicts.
# sys.modules[self.name] = module
spec.loader.exec_module(module)
f_init = "__init__.py"
if os.path.basename(path) == f_init:
dir_path = os.path.dirname(path)
for sub_obj in os.listdir(dir_path):
sub_obj_path = os.path.join(dir_path, sub_obj)
sub_obj_name, sub_obj_ext = os.path.splitext(sub_obj)
if (
os.path.isfile(sub_obj_path)
and sub_obj_ext == ".py"
and sub_obj != f_init
and sub_obj_name.isidentifier()
and sub_obj_name not in dir(module)
):
setattr(
module,
sub_obj_name,
load_python_module(name=sub_obj_name, path=sub_obj_path, files=files),
)
elif (
os.path.isdir(sub_obj_path)
and os.path.isfile(os.path.join(sub_obj_path, f_init))
and sub_obj_name.isidentifier()
and sub_obj_name not in dir(module)
):
setattr(
module,
sub_obj_name,
load_python_module(
name=sub_obj_name, path=os.path.join(sub_obj_path, f_init), files=files
),
)
return module


@dataclass
class PythonModule(Macro):
absolute_path: str = ""
relative_path: str = ""
resource_type: Literal[NodeType.PythonModule] = field(
metadata={"restrict": Literal[NodeType.PythonModule]}
)
files: List[str] = field(default_factory=list)

def __post_init__(self):
_ = self.module

@cached_property
def module(self) -> ModuleType:
module = load_python_module(name=self.name, path=self.absolute_path, files=self.files)
# # relative importing in the python macros is not supported if modules are not
# # added to sys.modules.
# # this may increase the probability of module name conflicts.
# sys.modules[self.name] = module
return module

def add_to_zipfile(
self,
zipfile: str,
mode: Literal["r", "w", "x", "a"] = "a",
is_in_root_package: bool = True,
):
len_root_path = len(self.absolute_path) - len(self.relative_path)
with ZipFile(zipfile, mode=mode) as f:
for file in self.files:
arcname = file[len_root_path:]
if not is_in_root_package:
arcname = os.path.join(self.package_name, arcname)
f.write(filename=file, arcname=arcname)


# ====================================
# Documentation node
# ====================================
Expand Down
Loading

0 comments on commit 886d57d

Please sign in to comment.