From 38aa9d5d8a76aa952011e0e749df0db697dfa155 Mon Sep 17 00:00:00 2001 From: Kshitij Aranke Date: Thu, 19 Sep 2024 19:50:11 +0100 Subject: [PATCH] Fix #9005: Allow singular tests to be documented --- core/dbt/contracts/graph/manifest.py | 50 ++++++++++++++++++++++++++- core/dbt/contracts/graph/nodes.py | 5 +++ core/dbt/contracts/graph/unparsed.py | 5 +++ core/dbt/parser/common.py | 2 ++ core/dbt/parser/schemas.py | 51 +++++++++++++++++++++++++++- 5 files changed, 111 insertions(+), 2 deletions(-) diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index f4cdafea737..b556b479fb4 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -58,6 +58,7 @@ SavedQuery, SeedNode, SemanticModel, + SingularTestNode, SourceDefinition, UnitTestDefinition, UnitTestFileFixture, @@ -89,7 +90,7 @@ RefName = str -def find_unique_id_for_package(storage, key, package: Optional[PackageName]): +def find_unique_id_for_package(storage, key, package: Optional[PackageName]) -> Optional[UniqueID]: if key not in storage: return None @@ -470,6 +471,43 @@ class AnalysisLookup(RefableLookup): _versioned_types: ClassVar[set] = set() +class SingularTestLookup(dbtClassMixin): + def __init__(self, manifest: "Manifest") -> None: + self.storage: Dict[str, Dict[PackageName, UniqueID]] = {} + self.populate(manifest) + + def get_unique_id(self, search_name, package: Optional[PackageName]) -> Optional[UniqueID]: + return find_unique_id_for_package(self.storage, search_name, package) + + def find( + self, search_name, package: Optional[PackageName], manifest: "Manifest" + ) -> Optional[SingularTestNode]: + unique_id = self.get_unique_id(search_name, package) + if unique_id is not None: + return self.perform_lookup(unique_id, manifest) + return None + + def add_singular_test(self, source: SingularTestNode) -> None: + if source.search_name not in self.storage: + self.storage[source.search_name] = {} + + self.storage[source.search_name][source.package_name] = source.unique_id + + def populate(self, manifest: "Manifest") -> None: + for node in manifest.nodes.values(): + if isinstance(node, SingularTestNode): + self.add_singular_test(node) + + def perform_lookup(self, unique_id: UniqueID, manifest: "Manifest") -> SingularTestNode: + if unique_id not in manifest.nodes: + raise dbt_common.exceptions.DbtInternalError( + f"Singular test {unique_id} found in cache but not found in manifest" + ) + node = manifest.nodes[unique_id] + assert isinstance(node, SingularTestNode) + return node + + def _packages_to_search( current_project: str, node_package: str, @@ -869,6 +907,9 @@ class Manifest(MacroMethods, dbtClassMixin): _analysis_lookup: Optional[AnalysisLookup] = field( default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None} ) + _singular_test_lookup: Optional[SingularTestLookup] = field( + default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None} + ) _parsing_info: ParsingInfo = field( default_factory=ParsingInfo, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}, @@ -1264,6 +1305,12 @@ def analysis_lookup(self) -> AnalysisLookup: self._analysis_lookup = AnalysisLookup(self) return self._analysis_lookup + @property + def singular_test_lookup(self) -> SingularTestLookup: + if self._singular_test_lookup is None: + self._singular_test_lookup = SingularTestLookup(self) + return self._singular_test_lookup + @property def external_node_unique_ids(self): return [node.unique_id for node in self.nodes.values() if node.is_external_node] @@ -1708,6 +1755,7 @@ def __reduce_ex__(self, protocol): self._semantic_model_by_measure_lookup, self._disabled_lookup, self._analysis_lookup, + self._singular_test_lookup, ) return self.__class__, args diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index b28910c0de3..22699c74afc 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -1627,6 +1627,11 @@ class ParsedMacroPatch(ParsedPatch): arguments: List[MacroArgument] = field(default_factory=list) +@dataclass +class ParsedSingularTestPatch(ParsedPatch): + pass + + # ==================================== # Node unions/categories # ==================================== diff --git a/core/dbt/contracts/graph/unparsed.py b/core/dbt/contracts/graph/unparsed.py index 847be3d3a2a..ebe704fc1c5 100644 --- a/core/dbt/contracts/graph/unparsed.py +++ b/core/dbt/contracts/graph/unparsed.py @@ -202,6 +202,11 @@ class UnparsedAnalysisUpdate(HasConfig, HasColumnDocs, HasColumnProps, HasYamlMe access: Optional[str] = None +@dataclass +class UnparsedSingularTestUpdate(HasConfig, HasColumnProps, HasYamlMetadata): + pass + + @dataclass class UnparsedNodeUpdate(HasConfig, HasColumnTests, HasColumnAndTestProps, HasYamlMetadata): quote_columns: Optional[bool] = None diff --git a/core/dbt/parser/common.py b/core/dbt/parser/common.py index 3bafbb9550f..d3bb640a78f 100644 --- a/core/dbt/parser/common.py +++ b/core/dbt/parser/common.py @@ -13,6 +13,7 @@ UnparsedMacroUpdate, UnparsedModelUpdate, UnparsedNodeUpdate, + UnparsedSingularTestUpdate, ) from dbt.exceptions import ParsingError from dbt.parser.search import FileBlock @@ -38,6 +39,7 @@ def trimmed(inp: str) -> str: UnpatchedSourceDefinition, UnparsedExposure, UnparsedModelUpdate, + UnparsedSingularTestUpdate, ) diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index 3a06756e355..e4f2fb90b31 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -17,6 +17,7 @@ ModelNode, ParsedMacroPatch, ParsedNodePatch, + ParsedSingularTestPatch, UnpatchedSourceDefinition, ) from dbt.contracts.graph.unparsed import ( @@ -27,6 +28,7 @@ UnparsedMacroUpdate, UnparsedModelUpdate, UnparsedNodeUpdate, + UnparsedSingularTestUpdate, UnparsedSourceDefinition, ) from dbt.events.types import ( @@ -221,6 +223,10 @@ def parse_file(self, block: FileBlock, dct: Optional[Dict] = None) -> None: parser = MacroPatchParser(self, yaml_block, "macros") parser.parse() + if "data_tests" in dct: + parser = SingularTestPatchParser(self, yaml_block, "data_tests") + parser.parse() + # PatchParser.parse() (but never test_blocks) if "analyses" in dct: parser = AnalysisPatchParser(self, yaml_block, "analyses") @@ -266,7 +272,9 @@ def parse_file(self, block: FileBlock, dct: Optional[Dict] = None) -> None: saved_query_parser.parse() -Parsed = TypeVar("Parsed", UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch) +Parsed = TypeVar( + "Parsed", UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch, ParsedSingularTestPatch +) NodeTarget = TypeVar("NodeTarget", UnparsedNodeUpdate, UnparsedAnalysisUpdate, UnparsedModelUpdate) NonSourceTarget = TypeVar( "NonSourceTarget", @@ -274,6 +282,7 @@ def parse_file(self, block: FileBlock, dct: Optional[Dict] = None) -> None: UnparsedAnalysisUpdate, UnparsedMacroUpdate, UnparsedModelUpdate, + UnparsedSingularTestUpdate, ) @@ -1055,6 +1064,46 @@ def _target_type(self) -> Type[UnparsedAnalysisUpdate]: return UnparsedAnalysisUpdate +class SingularTestPatchParser(PatchParser[UnparsedSingularTestUpdate, ParsedSingularTestPatch]): + def get_block(self, node: UnparsedSingularTestUpdate) -> TargetBlock: + return TargetBlock.from_yaml_block(self.yaml, node) + + def _target_type(self) -> Type[UnparsedSingularTestUpdate]: + return UnparsedSingularTestUpdate + + def parse_patch(self, block: TargetBlock[UnparsedSingularTestUpdate], refs: ParserRef) -> None: + patch = ParsedSingularTestPatch( + name=block.target.name, + description=block.target.description, + meta=block.target.meta, + docs=block.target.docs, + config=block.target.config, + original_file_path=block.target.original_file_path, + yaml_key=block.target.yaml_key, + package_name=block.target.package_name, + ) + + assert isinstance(self.yaml.file, SchemaSourceFile) + source_file: SchemaSourceFile = self.yaml.file + + unique_id = self.manifest.singular_test_lookup.get_unique_id( + block.name, block.target.package_name + ) + assert unique_id is not None + node = self.manifest.nodes.get(unique_id) + assert node is not None + + source_file.append_patch(patch.yaml_key, unique_id) + if patch.config: + self.patch_node_config(node, patch) + + node.patch_path = patch.file_id + node.description = patch.description + node.created_at = time.time() + node.meta = patch.meta + node.docs = patch.docs + + class MacroPatchParser(PatchParser[UnparsedMacroUpdate, ParsedMacroPatch]): def get_block(self, node: UnparsedMacroUpdate) -> TargetBlock: return TargetBlock.from_yaml_block(self.yaml, node)