diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index 2bd183af759..267abc6a23a 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -54,7 +54,6 @@ ManifestNode, Metric, ModelNode, - ResultNode, SavedQuery, SeedNode, SemanticModel, @@ -1586,7 +1585,7 @@ def add_disabled_nofile(self, node: GraphMemberNode): else: self.disabled[node.unique_id] = [node] - def add_disabled(self, source_file: AnySourceFile, node: ResultNode, test_from=None): + def add_disabled(self, source_file: AnySourceFile, node: GraphMemberNode, test_from=None): self.add_disabled_nofile(node) if isinstance(source_file, SchemaSourceFile): if isinstance(node, GenericTestNode): diff --git a/core/dbt/parser/base.py b/core/dbt/parser/base.py index 71a4d53d429..05d5b7d114d 100644 --- a/core/dbt/parser/base.py +++ b/core/dbt/parser/base.py @@ -7,7 +7,7 @@ from dbt.adapters.factory import get_adapter # noqa: F401 from dbt.artifacts.resources import Contract from dbt.clients.jinja import MacroGenerator, get_rendered -from dbt.config import Project, RuntimeConfig +from dbt.config import RuntimeConfig from dbt.context.context_config import ContextConfig from dbt.context.providers import ( generate_generate_name_macro_context, @@ -39,9 +39,9 @@ class BaseParser(Generic[FinalValue]): - def __init__(self, project: Project, manifest: Manifest) -> None: - self.project = project - self.manifest = manifest + def __init__(self, project: RuntimeConfig, manifest: Manifest) -> None: + self.project: RuntimeConfig = project + self.manifest: Manifest = manifest @abc.abstractmethod def parse_file(self, block: FileBlock) -> None: @@ -63,7 +63,7 @@ def generate_unique_id(self, resource_name: str, hash: Optional[str] = None) -> class Parser(BaseParser[FinalValue], Generic[FinalValue]): def __init__( self, - project: Project, + project: RuntimeConfig, manifest: Manifest, root_project: RuntimeConfig, ) -> None: @@ -121,7 +121,7 @@ class ConfiguredParser( ): def __init__( self, - project: Project, + project: RuntimeConfig, manifest: Manifest, root_project: RuntimeConfig, ) -> None: diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index a487ba4799e..7e37a0e5417 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -222,12 +222,12 @@ class ManifestLoader: def __init__( self, root_project: RuntimeConfig, - all_projects: Mapping[str, Project], + all_projects: Mapping[str, RuntimeConfig], macro_hook: Optional[Callable[[Manifest], Any]] = None, file_diff: Optional[FileDiff] = None, ) -> None: self.root_project: RuntimeConfig = root_project - self.all_projects: Mapping[str, Project] = all_projects + self.all_projects: Mapping[str, RuntimeConfig] = all_projects self.file_diff = file_diff self.manifest: Manifest = Manifest() self.new_manifest = self.manifest @@ -669,7 +669,7 @@ def load_and_parse_macros(self, project_parser_files): # 'parser_types' def parse_project( self, - project: Project, + project: RuntimeConfig, parser_files, parser_types: List[Type[Parser]], ) -> None: diff --git a/core/dbt/parser/schema_yaml_readers.py b/core/dbt/parser/schema_yaml_readers.py index 6bd1c33b6db..86ac10d7545 100644 --- a/core/dbt/parser/schema_yaml_readers.py +++ b/core/dbt/parser/schema_yaml_readers.py @@ -31,6 +31,7 @@ generate_parse_exposure, generate_parse_semantic_models, ) +from dbt.contracts.files import SchemaSourceFile from dbt.contracts.graph.nodes import Exposure, Group, Metric, SavedQuery, SemanticModel from dbt.contracts.graph.unparsed import ( UnparsedConversionTypeParams, @@ -85,7 +86,7 @@ def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock) -> None: self.schema_parser = schema_parser self.yaml = yaml - def parse_exposure(self, unparsed: UnparsedExposure): + def parse_exposure(self, unparsed: UnparsedExposure) -> None: package_name = self.project.project_name unique_id = f"{NodeType.Exposure}.{package_name}.{unparsed.name}" path = self.yaml.path.relative_path @@ -143,6 +144,7 @@ def parse_exposure(self, unparsed: UnparsedExposure): get_rendered(depends_on_jinja, ctx, parsed, capture_macros=True) # parsed now has a populated refs/sources/metrics + assert isinstance(self.yaml.file, SchemaSourceFile) if parsed.config.enabled: self.manifest.add_exposure(self.yaml.file, parsed) else: @@ -171,7 +173,7 @@ def _generate_exposure_config( patch_config_dict=precedence_configs, ) - def parse(self): + def parse(self) -> None: for data in self.get_key_dicts(): try: UnparsedExposure.validate(data) @@ -387,7 +389,7 @@ def _get_metric_type_params(self, unparsed_metric: UnparsedMetric) -> MetricType # input_measures=?, ) - def parse_metric(self, unparsed: UnparsedMetric, generated: bool = False): + def parse_metric(self, unparsed: UnparsedMetric, generated: bool = False) -> None: package_name = self.project.project_name unique_id = f"{NodeType.Metric}.{package_name}.{unparsed.name}" path = self.yaml.path.relative_path @@ -442,6 +444,7 @@ def parse_metric(self, unparsed: UnparsedMetric, generated: bool = False): ) # if the metric is disabled we do not want it included in the manifest, only in the disabled dict + assert isinstance(self.yaml.file, SchemaSourceFile) if parsed.config.enabled: self.manifest.add_metric(self.yaml.file, parsed, generated) else: @@ -471,7 +474,7 @@ def _generate_metric_config( ) return config - def parse(self): + def parse(self) -> None: for data in self.get_key_dicts(): try: UnparsedMetric.validate(data) @@ -488,7 +491,7 @@ def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock) -> None: self.schema_parser = schema_parser self.yaml = yaml - def parse_group(self, unparsed: UnparsedGroup): + def parse_group(self, unparsed: UnparsedGroup) -> None: package_name = self.project.project_name unique_id = f"{NodeType.Group}.{package_name}.{unparsed.name}" path = self.yaml.path.relative_path @@ -503,6 +506,7 @@ def parse_group(self, unparsed: UnparsedGroup): owner=unparsed.owner, ) + assert isinstance(self.yaml.file, SchemaSourceFile) self.manifest.add_group(self.yaml.file, parsed) def parse(self): @@ -635,7 +639,7 @@ def _generate_semantic_model_config( return config - def parse_semantic_model(self, unparsed: UnparsedSemanticModel): + def parse_semantic_model(self, unparsed: UnparsedSemanticModel) -> None: package_name = self.project.project_name unique_id = f"{NodeType.SemanticModel}.{package_name}.{unparsed.name}" path = self.yaml.path.relative_path @@ -695,6 +699,7 @@ def parse_semantic_model(self, unparsed: UnparsedSemanticModel): # if the semantic model is disabled we do not want it included in the manifest, # only in the disabled dict + assert isinstance(self.yaml.file, SchemaSourceFile) if parsed.config.enabled: self.manifest.add_semantic_model(self.yaml.file, parsed) else: @@ -705,7 +710,7 @@ def parse_semantic_model(self, unparsed: UnparsedSemanticModel): if measure.create_metric is True: self._create_metric(measure=measure, enabled=parsed.config.enabled) - def parse(self): + def parse(self) -> None: for data in self.get_key_dicts(): try: UnparsedSemanticModel.validate(data) @@ -831,6 +836,7 @@ def parse_saved_query(self, unparsed: UnparsedSavedQuery) -> None: delattr(export, "relation_name") # Only add thes saved query if it's enabled, otherwise we track it with other diabled nodes + assert isinstance(self.yaml.file, SchemaSourceFile) if parsed.config.enabled: self.manifest.add_saved_query(self.yaml.file, parsed) else: diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index 284a01fc58e..af912b455cb 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -6,9 +6,11 @@ from dbt import deprecations from dbt.clients.yaml_helper import load_yaml_text +from dbt.config import RuntimeConfig from dbt.context.configured import SchemaYamlVars, generate_schema_yml_context from dbt.context.context_config import ContextConfig -from dbt.contracts.files import SchemaSourceFile +from dbt.contracts.files import SchemaSourceFile, SourceFile +from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.nodes import ( ModelNode, ParsedMacroPatch, @@ -142,9 +144,9 @@ def yaml_from_file(source_file: SchemaSourceFile) -> Optional[Dict[str, Any]]: class SchemaParser(SimpleParser[YamlBlock, ModelNode]): def __init__( self, - project, - manifest, - root_project, + project: RuntimeConfig, + manifest: Manifest, + root_project: RuntimeConfig, ) -> None: super().__init__(project, manifest, root_project) @@ -282,33 +284,33 @@ class ParseResult: # PatchParser, SemanticModelParser, SavedQueryParser, UnitTestParser class YamlReader(metaclass=ABCMeta): def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock, key: str) -> None: - self.schema_parser = schema_parser + self.schema_parser: SchemaParser = schema_parser # key: models, seeds, snapshots, sources, macros, # analyses, exposures, unit_tests - self.key = key - self.yaml = yaml - self.schema_yaml_vars = SchemaYamlVars() + self.key: str = key + self.yaml: YamlBlock = yaml + self.schema_yaml_vars: SchemaYamlVars = SchemaYamlVars() self.render_ctx = generate_schema_yml_context( self.schema_parser.root_project, self.schema_parser.project.project_name, self.schema_yaml_vars, ) - self.renderer = SchemaYamlRenderer(self.render_ctx, self.key) + self.renderer: SchemaYamlRenderer = SchemaYamlRenderer(self.render_ctx, self.key) @property - def manifest(self): + def manifest(self) -> Manifest: return self.schema_parser.manifest @property - def project(self): + def project(self) -> RuntimeConfig: return self.schema_parser.project @property - def default_database(self): + def default_database(self) -> str: return self.schema_parser.default_database @property - def root_project(self): + def root_project(self) -> RuntimeConfig: return self.schema_parser.root_project # for the different schema subparsers ('models', 'source', etc) @@ -360,7 +362,7 @@ def render_entry(self, dct): return dct @abstractmethod - def parse(self) -> ParseResult: + def parse(self) -> Optional[ParseResult]: raise NotImplementedError("parse is abstract") @@ -425,7 +427,9 @@ def add_source_definitions(self, source: UnparsedSourceDefinition) -> None: fqn=fqn, name=f"{source.name}_{table.name}", ) - self.manifest.add_source(self.yaml.file, source_def) + assert isinstance(self.yaml.file, SchemaSourceFile) + source_file: SchemaSourceFile = self.yaml.file + self.manifest.add_source(source_file, source_def) # This class has two subclasses: NodePatchParser and MacroPatchParser @@ -515,7 +519,7 @@ def get_unparsed_target(self) -> Iterable[NonSourceTarget]: # We want to raise an error if some attributes are in two places, and move them # from toplevel to config if necessary - def normalize_attribute(self, data, path, attribute): + def normalize_attribute(self, data, path, attribute) -> None: if attribute in data: if "config" in data and attribute in data["config"]: raise ParsingError( @@ -529,31 +533,31 @@ def normalize_attribute(self, data, path, attribute): data["config"] = {} data["config"][attribute] = data.pop(attribute) - def normalize_meta_attribute(self, data, path): + def normalize_meta_attribute(self, data, path) -> None: return self.normalize_attribute(data, path, "meta") - def normalize_docs_attribute(self, data, path): + def normalize_docs_attribute(self, data, path) -> None: return self.normalize_attribute(data, path, "docs") - def normalize_group_attribute(self, data, path): + def normalize_group_attribute(self, data, path) -> None: return self.normalize_attribute(data, path, "group") - def normalize_contract_attribute(self, data, path): + def normalize_contract_attribute(self, data, path) -> None: return self.normalize_attribute(data, path, "contract") - def normalize_access_attribute(self, data, path): + def normalize_access_attribute(self, data, path) -> None: return self.normalize_attribute(data, path, "access") @property - def is_root_project(self): + def is_root_project(self) -> bool: if self.root_project.project_name == self.project.project_name: return True return False - def validate_data_tests(self, data): + def validate_data_tests(self, data) -> None: # Rename 'tests' -> 'data_tests' at both model-level and column-level # Raise a validation error if the user has defined both names - def validate_and_rename(data, is_root_project: bool): + def validate_and_rename(data, is_root_project: bool) -> None: if data.get("tests"): if "tests" in data and "data_tests" in data: raise ValidationError( @@ -583,7 +587,7 @@ def validate_and_rename(data, is_root_project: bool): for column in version["columns"]: validate_and_rename(column, self.is_root_project) - def patch_node_config(self, node, patch): + def patch_node_config(self, node, patch) -> None: if "access" in patch.config: if AccessType.is_valid(patch.config["access"]): patch.config["access"] = AccessType(patch.config["access"]) @@ -713,7 +717,7 @@ def parse_patch(self, block: TargetBlock[NodeTarget], refs: ParserRef) -> None: self.patch_node_properties(node, patch) - def patch_node_properties(self, node, patch: "ParsedNodePatch"): + def patch_node_properties(self, node, patch: "ParsedNodePatch") -> None: """Given a ParsedNodePatch, add the new information to the node.""" # explicitly pick out the parts to update so we don't inadvertently # step on the model name or anything @@ -784,7 +788,7 @@ def parse_patch(self, block: TargetBlock[UnparsedModelUpdate], refs: ParserRef) versioned_model_name, target.package_name, None ) - versioned_model_node = None + versioned_model_node: Optional[ModelNode] = None add_node_nofile_fn: Callable # If this is the latest version, it's allowed to define itself in a model file name that doesn't have a suffix @@ -808,12 +812,17 @@ def parse_patch(self, block: TargetBlock[UnparsedModelUpdate], refs: ParserRef) "in `dbt_project.yml` or in the sql files." ) raise ParsingError(msg) - versioned_model_node = self.manifest.disabled.pop( - found_nodes[0].unique_id - )[0] + # We know that there's only one node in the disabled list because + # otherwise we would have raised the error above + found_node = found_nodes[0] + self.manifest.disabled.pop(found_node.unique_id) + assert isinstance(found_node, ModelNode) + versioned_model_node = found_node add_node_nofile_fn = self.manifest.add_disabled_nofile else: - versioned_model_node = self.manifest.nodes.pop(versioned_model_unique_id) + found_node = self.manifest.nodes.pop(versioned_model_unique_id) + assert isinstance(found_node, ModelNode) + versioned_model_node = found_node add_node_nofile_fn = self.manifest.add_node_nofile if versioned_model_node is None: @@ -832,12 +841,12 @@ def parse_patch(self, block: TargetBlock[UnparsedModelUpdate], refs: ParserRef) f"model.{target.package_name}.{target.name}.{unparsed_version.formatted_v}" ) # update source file.nodes with new unique_id - self.manifest.files[versioned_model_node.file_id].nodes.remove( - versioned_model_node_unique_id_old - ) - self.manifest.files[versioned_model_node.file_id].nodes.append( - versioned_model_node.unique_id - ) + model_source_file = self.manifest.files[versioned_model_node.file_id] + assert isinstance(model_source_file, SourceFile) + # because of incomplete test setup, check before removing + if versioned_model_node_unique_id_old in model_source_file.nodes: + model_source_file.nodes.remove(versioned_model_node_unique_id_old) + model_source_file.nodes.append(versioned_model_node.unique_id) # update versioned node fqn versioned_model_node.fqn[-1] = target.name @@ -889,7 +898,7 @@ def parse_patch(self, block: TargetBlock[UnparsedModelUpdate], refs: ParserRef) def _target_type(self) -> Type[UnparsedModelUpdate]: return UnparsedModelUpdate - def patch_node_properties(self, node, patch: "ParsedNodePatch"): + def patch_node_properties(self, node, patch: "ParsedNodePatch") -> None: super().patch_node_properties(node, patch) node.version = patch.version node.latest_version = patch.latest_version @@ -906,7 +915,7 @@ def patch_node_properties(self, node, patch: "ParsedNodePatch"): self.patch_constraints(node, patch.constraints) node.build_contract_checksum() - def patch_constraints(self, node, constraints): + def patch_constraints(self, node, constraints) -> None: contract_config = node.config.get("contract") if contract_config.enforced is True: self._validate_constraint_prerequisites(node) @@ -922,7 +931,9 @@ def patch_constraints(self, node, constraints): self._validate_pk_constraints(node, constraints) node.constraints = [ModelLevelConstraint.from_dict(c) for c in constraints] - def _validate_pk_constraints(self, model_node: ModelNode, constraints: List[Dict[str, Any]]): + def _validate_pk_constraints( + self, model_node: ModelNode, constraints: List[Dict[str, Any]] + ) -> None: errors = [] # check for primary key constraints defined at the column level pk_col: List[str] = [] @@ -955,7 +966,7 @@ def _validate_pk_constraints(self, model_node: ModelNode, constraints: List[Dict + "\n".join(errors) ) - def _validate_constraint_prerequisites(self, model_node: ModelNode): + def _validate_constraint_prerequisites(self, model_node: ModelNode) -> None: column_warn_unsupported = [ constraint.warn_unsupported for column in model_node.columns.values() diff --git a/core/dbt/parser/unit_tests.py b/core/dbt/parser/unit_tests.py index 6057ea77cf3..cefe634c8c3 100644 --- a/core/dbt/parser/unit_tests.py +++ b/core/dbt/parser/unit_tests.py @@ -292,6 +292,7 @@ def parse(self) -> ParseResult: # for calculating state:modified unit_test_definition.build_unit_test_checksum() + assert isinstance(self.yaml.file, SchemaSourceFile) self.manifest.add_unit_test(self.yaml.file, unit_test_definition) return ParseResult() diff --git a/tests/unit/parser/test_parser.py b/tests/unit/parser/test_parser.py index 42398f48f39..20a2c9e8c83 100644 --- a/tests/unit/parser/test_parser.py +++ b/tests/unit/parser/test_parser.py @@ -162,7 +162,7 @@ def tearDown(self): self.parser_patcher.stop() self.patcher.stop() - def file_block_for(self, data: str, filename: str, searched: str): + def source_file_for(self, data: str, filename: str, searched: str): root_dir = get_abs_os_path("./dbt_packages/snowplow") filename = normalize(filename) path = FilePath( @@ -178,6 +178,10 @@ def file_block_for(self, data: str, filename: str, searched: str): project_name="snowplow", ) source_file.contents = data + return source_file + + def file_block_for(self, data: str, filename: str, searched: str): + source_file = self.source_file_for(data, filename, searched) return FileBlock(file=source_file) def assert_has_manifest_lengths( @@ -580,9 +584,11 @@ def setUp(self): sources=[], patch_path=None, ) + source_file = self.source_file_for("", "my_model.sql", "models") nodes = {my_model_node.unique_id: my_model_node} macros = {m.unique_id: m for m in generate_name_macros("root")} self.manifest = Manifest(nodes=nodes, macros=macros) + self.manifest.files[source_file.file_id] = source_file self.manifest.ref_lookup self.parser = SchemaParser( project=self.snowplow_project_config, @@ -702,6 +708,7 @@ def setUp(self): patch_path=None, file_id="snowplow://models/arbitrary_file_name.sql", ) + my_model_v1_source_file = self.source_file_for("", "arbitrary_file_name.sql", "models") my_model_v2_node = MockNode( package="snowplow", name="my_model_v2", @@ -711,12 +718,16 @@ def setUp(self): patch_path=None, file_id="snowplow://models/my_model_v2.sql", ) + my_model_v2_source_file = self.source_file_for("", "my_model_v2.sql", "models") nodes = { my_model_v1_node.unique_id: my_model_v1_node, my_model_v2_node.unique_id: my_model_v2_node, } macros = {m.unique_id: m for m in generate_name_macros("root")} - files = {node.file_id: mock.MagicMock(nodes=[node.unique_id]) for node in nodes.values()} + files = { + my_model_v1_source_file.file_id: my_model_v1_source_file, + my_model_v2_source_file.file_id: my_model_v2_source_file, + } self.manifest = Manifest(nodes=nodes, macros=macros, files=files) self.manifest.ref_lookup self.parser = SchemaParser(