diff --git a/.changes/unreleased/Features-20240531-150816.yaml b/.changes/unreleased/Features-20240531-150816.yaml new file mode 100644 index 00000000000..ebe69c0c5e3 --- /dev/null +++ b/.changes/unreleased/Features-20240531-150816.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Update data_test to accept arbitrary config options +time: 2024-05-31T15:08:16.431966-05:00 +custom: + Author: McKnight-42 + Issue: "10197" diff --git a/.changes/unreleased/Fixes-20240113-073615.yaml b/.changes/unreleased/Fixes-20240113-073615.yaml new file mode 100644 index 00000000000..3dd68508db8 --- /dev/null +++ b/.changes/unreleased/Fixes-20240113-073615.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Convert "Skipping model due to fail_fast" message to DEBUG level +time: 2024-01-13T07:36:15.836294-00:00 +custom: + Author: scottgigante,nevdelap + Issue: "8774" diff --git a/.changes/unreleased/Fixes-20240605-111652.yaml b/.changes/unreleased/Fixes-20240605-111652.yaml new file mode 100644 index 00000000000..25c756db86b --- /dev/null +++ b/.changes/unreleased/Fixes-20240605-111652.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Fix issues with selectors and inline nodes +time: 2024-06-05T11:16:52.187667-04:00 +custom: + Author: gshank + Issue: 8943 9269 diff --git a/.changes/unreleased/Fixes-20240607-134648.yaml b/.changes/unreleased/Fixes-20240607-134648.yaml new file mode 100644 index 00000000000..f40b98678f9 --- /dev/null +++ b/.changes/unreleased/Fixes-20240607-134648.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Fix snapshot config to work in yaml files +time: 2024-06-07T13:46:48.383215-04:00 +custom: + Author: gshank + Issue: "4000" diff --git a/core/dbt/artifacts/resources/v1/snapshot.py b/core/dbt/artifacts/resources/v1/snapshot.py index 6164d953184..dee235a19df 100644 --- a/core/dbt/artifacts/resources/v1/snapshot.py +++ b/core/dbt/artifacts/resources/v1/snapshot.py @@ -18,39 +18,35 @@ class SnapshotConfig(NodeConfig): # Not using Optional because of serialization issues with a Union of str and List[str] check_cols: Union[str, List[str], None] = None - @classmethod - def validate(cls, data): - super().validate(data) - # Note: currently you can't just set these keys in schema.yml because this validation - # will fail when parsing the snapshot node. - if not data.get("strategy") or not data.get("unique_key") or not data.get("target_schema"): + def final_validate(self): + if not self.strategy or not self.unique_key or not self.target_schema: raise ValidationError( "Snapshots must be configured with a 'strategy', 'unique_key', " "and 'target_schema'." ) - if data.get("strategy") == "check": - if not data.get("check_cols"): + if self.strategy == "check": + if not self.check_cols: raise ValidationError( "A snapshot configured with the check strategy must " "specify a check_cols configuration." ) - if isinstance(data["check_cols"], str) and data["check_cols"] != "all": + if isinstance(self.check_cols, str) and self.check_cols != "all": raise ValidationError( - f"Invalid value for 'check_cols': {data['check_cols']}. " + f"Invalid value for 'check_cols': {self.check_cols}. " "Expected 'all' or a list of strings." ) - elif data.get("strategy") == "timestamp": - if not data.get("updated_at"): + elif self.strategy == "timestamp": + if not self.updated_at: raise ValidationError( "A snapshot configured with the timestamp strategy " "must specify an updated_at configuration." ) - if data.get("check_cols"): + if self.check_cols: raise ValidationError("A 'timestamp' snapshot should not have 'check_cols'") # If the strategy is not 'check' or 'timestamp' it's a custom strategy, # formerly supported with GenericSnapshotConfig - if data.get("materialized") and data.get("materialized") != "snapshot": + if self.materialized and self.materialized != "snapshot": raise ValidationError("A snapshot must have a materialized value of 'snapshot'") # Called by "calculate_node_config_dict" in ContextConfigGenerator diff --git a/core/dbt/cli/flags.py b/core/dbt/cli/flags.py index d3cfd707cbb..a74172484f3 100644 --- a/core/dbt/cli/flags.py +++ b/core/dbt/cli/flags.py @@ -289,6 +289,10 @@ def _assign_params( params_assigned_from_default, ["WARN_ERROR", "WARN_ERROR_OPTIONS"] ) + # Handle arguments mutually exclusive with INLINE + self._assert_mutually_exclusive(params_assigned_from_default, ["SELECT", "INLINE"]) + self._assert_mutually_exclusive(params_assigned_from_default, ["SELECTOR", "INLINE"]) + # Support lower cased access for legacy code. params = set( x for x in dir(self) if not callable(getattr(self, x)) and not x.startswith("__") @@ -315,7 +319,9 @@ def _assert_mutually_exclusive( """ set_flag = None for flag in group: - flag_set_by_user = flag.lower() not in params_assigned_from_default + flag_set_by_user = ( + hasattr(self, flag) and flag.lower() not in params_assigned_from_default + ) if flag_set_by_user and set_flag: raise DbtUsageException( f"{flag.lower()}: not allowed with argument {set_flag.lower()}" diff --git a/core/dbt/parser/generic_test_builders.py b/core/dbt/parser/generic_test_builders.py index 8a4864be82e..6bca8300dae 100644 --- a/core/dbt/parser/generic_test_builders.py +++ b/core/dbt/parser/generic_test_builders.py @@ -114,7 +114,8 @@ def __init__( self.package_name: str = package_name self.target: Testable = target self.version: Optional[NodeVersion] = version - + self.render_ctx: Dict[str, Any] = render_ctx + self.column_name: Optional[str] = column_name self.args["model"] = self.build_model_str() match = self.TEST_NAME_PATTERN.match(test_name) @@ -125,39 +126,12 @@ def __init__( self.name: str = groups["test_name"] self.namespace: str = groups["test_namespace"] self.config: Dict[str, Any] = {} + # Process legacy args + self.config.update(self._process_legacy_args()) - # This code removes keys identified as config args from the test entry - # dictionary. The keys remaining in the 'args' dictionary will be - # "kwargs", or keyword args that are passed to the test macro. - # The "kwargs" are not rendered into strings until compilation time. - # The "configs" are rendered here (since they were not rendered back - # in the 'get_key_dicts' methods in the schema parsers). - for key in self.CONFIG_ARGS: - value = self.args.pop(key, None) - # 'modifier' config could be either top level arg or in config - if value and "config" in self.args and key in self.args["config"]: - raise SameKeyNestedError() - if not value and "config" in self.args: - value = self.args["config"].pop(key, None) - if isinstance(value, str): - - try: - value = get_rendered(value, render_ctx, native=True) - except UndefinedMacroError as e: - - raise CustomMacroPopulatingConfigValueError( - target_name=self.target.name, - column_name=column_name, - name=self.name, - key=key, - err_msg=e.msg, - ) - - if value is not None: - self.config[key] = value - + # Process config args if present if "config" in self.args: - del self.args["config"] + self.config.update(self._render_values(self.args.pop("config", {}))) if self.namespace is not None: self.package_name = self.namespace @@ -182,6 +156,36 @@ def __init__( if short_name != full_name and "alias" not in self.config: self.config["alias"] = short_name + def _process_legacy_args(self): + config = {} + for key in self.CONFIG_ARGS: + value = self.args.pop(key, None) + if value and "config" in self.args and key in self.args["config"]: + raise SameKeyNestedError() + if not value and "config" in self.args: + value = self.args["config"].pop(key, None) + config[key] = value + + return self._render_values(config) + + def _render_values(self, config: Dict[str, Any]) -> Dict[str, Any]: + rendered_config = {} + for key, value in config.items(): + if isinstance(value, str): + try: + value = get_rendered(value, self.render_ctx, native=True) + except UndefinedMacroError as e: + raise CustomMacroPopulatingConfigValueError( + target_name=self.target.name, + column_name=self.column_name, + name=self.name, + key=key, + err_msg=e.msg, + ) + if value is not None: + rendered_config[key] = value + return rendered_config + def _bad_type(self) -> TypeError: return TypeError('invalid target type "{}"'.format(type(self.target))) diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index 1ac2ca6acf2..74e8f226ab2 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -467,6 +467,7 @@ def load(self) -> Manifest: self.process_model_inferred_primary_keys() self.check_valid_group_config() self.check_valid_access_property() + self.check_valid_snapshot_config() semantic_manifest = SemanticManifest(self.manifest) if not semantic_manifest.validate(): @@ -1345,6 +1346,16 @@ def check_valid_access_property(self): materialization=node.get_materialization(), ) + def check_valid_snapshot_config(self): + # Snapshot config can be set in either SQL files or yaml files, + # so we need to validate afterward. + for node in self.manifest.nodes.values(): + if node.resource_type != NodeType.Snapshot: + continue + if node.created_at < self.started_at: + continue + node.config.final_validate() + def write_perf_info(self, target_path: str): path = os.path.join(target_path, PERF_INFO_FILE_NAME) write_file(path, json.dumps(self._perf_info, cls=dbt.utils.JSONEncoder, indent=4)) diff --git a/core/dbt/task/compile.py b/core/dbt/task/compile.py index 5470c67c68d..d2460852fc5 100644 --- a/core/dbt/task/compile.py +++ b/core/dbt/task/compile.py @@ -104,6 +104,12 @@ def _runtime_initialize(self): ) sql_node = block_parser.parse_remote(self.args.inline, "inline_query") process_node(self.config, self.manifest, sql_node) + # Special hack to remove disabled, if it's there. This would only happen + # if all models are disabled in dbt_project + if sql_node.config.enabled is False: + sql_node.config.enabled = True + self.manifest.disabled.pop(sql_node.unique_id) + self.manifest.nodes[sql_node.unique_id] = sql_node # keep track of the node added to the manifest self._inline_node_id = sql_node.unique_id except CompilationError as exc: diff --git a/core/dbt/task/printer.py b/core/dbt/task/printer.py index 953a967b4a2..7bedbfaba93 100644 --- a/core/dbt/task/printer.py +++ b/core/dbt/task/printer.py @@ -13,6 +13,7 @@ StatsLine, ) from dbt.node_types import NodeType +from dbt_common.events.base_types import EventLevel from dbt_common.events.format import pluralize from dbt_common.events.functions import fire_event from dbt_common.events.types import Formatting @@ -68,14 +69,13 @@ def print_run_status_line(results) -> None: def print_run_result_error(result, newline: bool = True, is_warning: bool = False) -> None: - if newline: - fire_event(Formatting("")) - # set node_info for logging events node_info = None if hasattr(result, "node") and result.node: node_info = result.node.node_info if result.status == NodeStatus.Fail or (is_warning and result.status == NodeStatus.Warn): + if newline: + fire_event(Formatting("")) if is_warning: fire_event( RunResultWarning( @@ -112,8 +112,13 @@ def print_run_result_error(result, newline: bool = True, is_warning: bool = Fals fire_event( CheckNodeTestFailure(relation_name=result.node.relation_name, node_info=node_info) ) - + elif result.status == NodeStatus.Skipped and result.message is not None: + if newline: + fire_event(Formatting(""), level=EventLevel.DEBUG) + fire_event(RunResultError(msg=result.message), level=EventLevel.DEBUG) elif result.message is not None: + if newline: + fire_event(Formatting("")) fire_event(RunResultError(msg=result.message, node_info=node_info)) diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index a01e7a06c22..6afbcf8597e 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -5,7 +5,7 @@ from datetime import datetime from multiprocessing.dummy import Pool as ThreadPool from pathlib import Path -from typing import AbstractSet, Dict, Iterable, List, Optional, Set, Tuple +from typing import AbstractSet, Dict, Iterable, List, Optional, Set, Tuple, Union import dbt.exceptions import dbt.tracking @@ -108,7 +108,11 @@ def exclusion_arg(self): def get_selection_spec(self) -> SelectionSpec: default_selector_name = self.config.get_default_selector_name() - if self.args.selector: + spec: Union[SelectionSpec, bool] + if hasattr(self.args, "inline") and self.args.inline: + # We want an empty selection spec. + spec = parse_difference(None, None) + elif self.args.selector: # use pre-defined selector (--selector) spec = self.config.get_selector(self.args.selector) elif not (self.selection_arg or self.exclusion_arg) and default_selector_name: diff --git a/core/setup.py b/core/setup.py index c6c81dffa43..56572111054 100644 --- a/core/setup.py +++ b/core/setup.py @@ -69,7 +69,7 @@ # Accept patches but avoid automatically updating past a set minor version range. "dbt-extractor>=0.5.0,<=0.6", "minimal-snowplow-tracker>=0.0.2,<0.1", - "dbt-semantic-interfaces>=0.5.1,<0.6", + "dbt-semantic-interfaces>=0.5.1,<0.7", # Minor versions for these are expected to be backwards-compatible "dbt-common>=1.3.0,<2.0", "dbt-adapters>=1.1.1,<2.0", diff --git a/tests/functional/configs/test_configs.py b/tests/functional/configs/test_configs.py index 8d520f1ff80..2bbfac85c5c 100644 --- a/tests/functional/configs/test_configs.py +++ b/tests/functional/configs/test_configs.py @@ -2,7 +2,6 @@ import pytest -from dbt.exceptions import ParsingError from dbt.tests.util import ( check_relations_equal, run_dbt, @@ -120,7 +119,7 @@ def test_snapshots_materialization_proj_config(self, project): snapshots_dir = os.path.join(project.project_root, "snapshots") write_file(simple_snapshot, snapshots_dir, "mysnapshot.sql") - with pytest.raises(ParsingError): + with pytest.raises(ValidationError): run_dbt() diff --git a/tests/functional/dbt_runner/test_dbt_runner.py b/tests/functional/dbt_runner/test_dbt_runner.py index 80b94b9c73a..0b1607a2eba 100644 --- a/tests/functional/dbt_runner/test_dbt_runner.py +++ b/tests/functional/dbt_runner/test_dbt_runner.py @@ -36,6 +36,9 @@ def test_command_mutually_exclusive_option(self, dbt: dbtRunner) -> None: res = dbt.invoke(["deps", "--warn-error", "--warn-error-options", '{"include": "all"}']) assert type(res.exception) == DbtUsageException + res = dbt.invoke(["compile", "--select", "models", "--inline", "select 1 as id"]) + assert type(res.exception) == DbtUsageException + def test_invalid_command(self, dbt: dbtRunner) -> None: res = dbt.invoke(["invalid-command"]) assert type(res.exception) == DbtUsageException diff --git a/tests/functional/dependencies/test_dependency_options.py b/tests/functional/dependencies/test_dependency_options.py index 067fd7bf1e5..bf176831aaa 100644 --- a/tests/functional/dependencies/test_dependency_options.py +++ b/tests/functional/dependencies/test_dependency_options.py @@ -39,7 +39,7 @@ def test_deps_lock(self, clean_start): - package: fivetran/fivetran_utils version: 0.4.7 - package: dbt-labs/dbt_utils - version: 1.1.1 + version: 1.2.0 sha1_hash: 71304bca2138cf8004070b3573a1e17183c0c1a8 """ ) @@ -56,7 +56,7 @@ def test_deps_default(self, clean_start): - package: fivetran/fivetran_utils version: 0.4.7 - package: dbt-labs/dbt_utils - version: 1.1.1 + version: 1.2.0 sha1_hash: 71304bca2138cf8004070b3573a1e17183c0c1a8 """ ) diff --git a/tests/functional/graph_selection/test_inline.py b/tests/functional/graph_selection/test_inline.py new file mode 100644 index 00000000000..bf01ec8ae6a --- /dev/null +++ b/tests/functional/graph_selection/test_inline.py @@ -0,0 +1,64 @@ +import pytest + +from dbt.cli.exceptions import DbtUsageException +from dbt.tests.util import run_dbt, run_dbt_and_capture, write_file + +selectors_yml = """ + selectors: + - name: test_selector + description: Exclude everything + default: true + definition: + method: package + value: "foo" + """ + +dbt_project_yml = """ +name: test +profile: test +flags: + send_anonymous_usage_stats: false +""" + +dbt_project_yml_disabled_models = """ +name: test +profile: test +flags: + send_anonymous_usage_stats: false +models: + +enabled: false +""" + + +class TestCompileInlineWithSelector: + @pytest.fixture(scope="class") + def models(self): + return { + "first_model.sql": "select 1 as id", + } + + @pytest.fixture(scope="class") + def selectors(self): + return selectors_yml + + def test_inline_selectors(self, project): + (results, log_output) = run_dbt_and_capture( + ["compile", "--inline", "select * from {{ ref('first_model') }}"] + ) + assert len(results) == 1 + assert "Compiled inline node is:" in log_output + + # Set all models to disabled, check that we still get inline result + write_file(dbt_project_yml_disabled_models, project.project_root, "dbt_project.yml") + (results, log_output) = run_dbt_and_capture(["compile", "--inline", "select 1 as id"]) + assert len(results) == 1 + + # put back non-disabled dbt_project and check for mutually exclusive error message + # for --select and --inline + write_file(dbt_project_yml, project.project_root, "dbt_project.yml") + with pytest.raises(DbtUsageException): + run_dbt(["compile", "--select", "first_model", "--inline", "select 1 as id"]) + + # check for mutually exclusive --selector and --inline + with pytest.raises(DbtUsageException): + run_dbt(["compile", "--selector", "test_selector", "--inline", "select 1 as id"]) diff --git a/tests/functional/minimal_cli/test_minimal_cli.py b/tests/functional/minimal_cli/test_minimal_cli.py index d47f8b911c5..c757b43d4b3 100644 --- a/tests/functional/minimal_cli/test_minimal_cli.py +++ b/tests/functional/minimal_cli/test_minimal_cli.py @@ -53,6 +53,38 @@ def test_build(self, runner, project): assert "SKIP=1" in result.output +class TestBuildFailFast(BaseConfigProject): + def test_build(self, runner, project): + runner.invoke(cli, ["deps"]) + result = runner.invoke(cli, ["build", "--fail-fast"]) + # 1 seed, 1 model, 2 data tests + assert "PASS=4" in result.output + # 2 data tests + assert "ERROR=2" in result.output + # Singular test + assert "WARN=1" in result.output + # 1 snapshot + assert "SKIP=1" in result.output + # Skipping due to fail_fast is not shown when --debug is not specified. + assert "Skipping due to fail_fast" not in result.output + + +class TestBuildFailFastDebug(BaseConfigProject): + def test_build(self, runner, project): + runner.invoke(cli, ["deps"]) + result = runner.invoke(cli, ["build", "--fail-fast", "--debug"]) + # 1 seed, 1 model, 2 data tests + assert "PASS=4" in result.output + # 2 data tests + assert "ERROR=2" in result.output + # Singular test + assert "WARN=1" in result.output + # 1 snapshot + assert "SKIP=1" in result.output + # Skipping due to fail_fast is shown when --debug is specified. + assert "Skipping due to fail_fast" in result.output + + class TestDocsGenerate(BaseConfigProject): def test_docs_generate(self, runner, project): runner.invoke(cli, ["deps"]) diff --git a/tests/functional/schema_tests/data_test_config.py b/tests/functional/schema_tests/data_test_config.py new file mode 100644 index 00000000000..377f14aac04 --- /dev/null +++ b/tests/functional/schema_tests/data_test_config.py @@ -0,0 +1,115 @@ +import re + +import pytest + +from dbt.exceptions import CompilationError +from dbt.tests.util import get_manifest, run_dbt +from tests.functional.schema_tests.fixtures import ( + custom_config_yml, + mixed_config_yml, + same_key_error_yml, + seed_csv, + table_sql, +) + + +class BaseDataTestsConfig: + @pytest.fixture(scope="class") + def seeds(self): + return {"seed.csv": seed_csv} + + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "config-version": 2, + } + + @pytest.fixture(scope="class", autouse=True) + def setUp(self, project): + run_dbt(["seed"]) + + +class TestCustomDataTestConfig(BaseDataTestsConfig): + @pytest.fixture(scope="class") + def models(self): + return {"table.sql": table_sql, "custom_config.yml": custom_config_yml} + + def test_custom_config(self, project): + run_dbt(["parse"]) + manifest = get_manifest(project.project_root) + + # Pattern to match the test_id without the specific suffix + pattern = re.compile(r"test\.test\.accepted_values_table_color__blue__red\.\d+") + + # Find the test_id dynamically + test_id = None + for node_id in manifest.nodes: + if pattern.match(node_id): + test_id = node_id + break + + # Ensure the test_id was found + assert ( + test_id is not None + ), "Test ID matching the pattern was not found in the manifest nodes" + + # Proceed with the assertions + test_node = manifest.nodes[test_id] + assert "custom_config_key" in test_node.config + assert test_node.config["custom_config_key"] == "some_value" + + +class TestMixedDataTestConfig(BaseDataTestsConfig): + @pytest.fixture(scope="class") + def models(self): + return {"table.sql": table_sql, "mixed_config.yml": mixed_config_yml} + + def test_mixed_config(self, project): + run_dbt(["parse"]) + manifest = get_manifest(project.project_root) + + # Pattern to match the test_id without the specific suffix + pattern = re.compile(r"test\.test\.accepted_values_table_color__blue__red\.\d+") + + # Find the test_id dynamically + test_id = None + for node_id in manifest.nodes: + if pattern.match(node_id): + test_id = node_id + break + + # Ensure the test_id was found + assert ( + test_id is not None + ), "Test ID matching the pattern was not found in the manifest nodes" + + # Proceed with the assertions + test_node = manifest.nodes[test_id] + assert "custom_config_key" in test_node.config + assert test_node.config["custom_config_key"] == "some_value" + assert "severity" in test_node.config + assert test_node.config["severity"] == "warn" + + +class TestSameKeyErrorDataTestConfig: + @pytest.fixture(scope="class") + def models(self): + return {"table.sql": table_sql, "same_key_error.yml": same_key_error_yml} + + def test_same_key_error(self, project): + """ + Test that verifies dbt raises a CompilationError when the test configuration + contains the same key at the top level and inside the config dictionary. + """ + # Run dbt and expect a CompilationError due to the invalid configuration + with pytest.raises(CompilationError) as exc_info: + run_dbt(["parse"]) + + # Extract the exception message + exception_message = str(exc_info.value) + + # Assert that the error message contains the expected text + assert "Test cannot have the same key at the top-level and in config" in exception_message + + # Assert that the error message contains the context of the error + assert "models/same_key_error.yml" in exception_message diff --git a/tests/functional/schema_tests/fixtures.py b/tests/functional/schema_tests/fixtures.py index 51ae067bd84..bf16148e0c7 100644 --- a/tests/functional/schema_tests/fixtures.py +++ b/tests/functional/schema_tests/fixtures.py @@ -1273,3 +1273,63 @@ data_tests: - my_custom_test """ + +custom_config_yml = """ +version: 2 +models: + - name: table + columns: + - name: color + data_tests: + - accepted_values: + values: ['blue', 'red'] + config: + custom_config_key: some_value +""" + +mixed_config_yml = """ +version: 2 +models: + - name: table + columns: + - name: color + data_tests: + - accepted_values: + values: ['blue', 'red'] + severity: warn + config: + custom_config_key: some_value +""" + +same_key_error_yml = """ +version: 2 +models: + - name: table + columns: + - name: color + data_tests: + - accepted_values: + values: ['blue', 'red'] + severity: warn + config: + severity: error +""" + +seed_csv = """ +id,color,value +1,blue,10 +2,red,20 +3,green,30 +4,yellow,40 +5,blue,50 +6,red,60 +7,blue,70 +8,green,80 +9,yellow,90 +10,blue,100 +""" + +table_sql = """ +-- content of the table.sql +select * from {{ ref('seed') }} +""" diff --git a/tests/functional/simple_snapshot/fixtures.py b/tests/functional/simple_snapshot/fixtures.py index 04e4905d4cb..a94f0c04875 100644 --- a/tests/functional/simple_snapshot/fixtures.py +++ b/tests/functional/simple_snapshot/fixtures.py @@ -86,7 +86,6 @@ models__schema_yml = """ -version: 2 snapshots: - name: snapshot_actual data_tests: @@ -97,7 +96,6 @@ """ models__schema_with_target_schema_yml = """ -version: 2 snapshots: - name: snapshot_actual data_tests: diff --git a/tests/functional/simple_snapshot/test_missing_strategy_snapshot.py b/tests/functional/simple_snapshot/test_missing_strategy_snapshot.py index 407cd15439f..46543da8f4b 100644 --- a/tests/functional/simple_snapshot/test_missing_strategy_snapshot.py +++ b/tests/functional/simple_snapshot/test_missing_strategy_snapshot.py @@ -1,7 +1,7 @@ import pytest -from dbt.exceptions import ParsingError from dbt.tests.util import run_dbt +from dbt_common.dataclass_schema import ValidationError from tests.functional.simple_snapshot.fixtures import ( macros__test_no_overlaps_sql, models__ref_snapshot_sql, @@ -10,7 +10,7 @@ snapshots_invalid__snapshot_sql = """ {# make sure to never name this anything with `target_schema` in the name, or the test will be invalid! #} -{% snapshot missing_field_target_underscore_schema %} +{% snapshot snapshot_actual %} {# missing the mandatory target_schema parameter #} {{ config( @@ -44,7 +44,10 @@ def macros(): def test_missing_strategy(project): - with pytest.raises(ParsingError) as exc: + with pytest.raises(ValidationError) as exc: run_dbt(["compile"], expect_pass=False) - assert "Snapshots must be configured with a 'strategy'" in str(exc.value) + assert ( + "Snapshots must be configured with a 'strategy', 'unique_key', and 'target_schema'" + in str(exc.value) + ) diff --git a/tests/functional/simple_snapshot/test_snapshot_config.py b/tests/functional/simple_snapshot/test_snapshot_config.py new file mode 100644 index 00000000000..5124cf9c38b --- /dev/null +++ b/tests/functional/simple_snapshot/test_snapshot_config.py @@ -0,0 +1,67 @@ +import pytest + +from dbt.tests.util import run_dbt, write_file + +orders_sql = """ +select 1 as id, 101 as user_id, 'pending' as status +""" + +snapshot_sql = """ +{% snapshot orders_snapshot %} + +{{ + config( + target_schema=schema, + strategy='check', + unique_key='id', + check_cols=['status'], + ) +}} + +select * from {{ ref('orders') }} + +{% endsnapshot %} +""" + +snapshot_no_config_sql = """ +{% snapshot orders_snapshot %} + +select * from {{ ref('orders') }} + +{% endsnapshot %} +""" + +snapshot_schema_yml = """ +snapshots: + - name: orders_snapshot + config: + target_schema: test + strategy: check + unique_key: id + check_cols: ['status'] +""" + + +class TestSnapshotConfig: + @pytest.fixture(scope="class") + def models(self): + return {"orders.sql": orders_sql} + + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot_orders.sql": snapshot_sql} + + def test_config(self, project): + run_dbt(["run"]) + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + # try to parse with config in schema file + write_file( + snapshot_no_config_sql, project.project_root, "snapshots", "snapshot_orders.sql" + ) + write_file(snapshot_schema_yml, project.project_root, "snapshots", "snapshot.yml") + results = run_dbt(["parse"]) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 diff --git a/tests/unit/README.md b/tests/unit/README.md index b9e22124c85..e063e370b48 100644 --- a/tests/unit/README.md +++ b/tests/unit/README.md @@ -1,6 +1,5 @@ # Unit test README -## test_contracts_graph_parsed.py ### The Why We need to ensure that we can go from objects to dictionaries and back without any @@ -16,3 +15,7 @@ versions of the object we're interested in testing, and run the different genera of the object through the test. This gives us confidence that for any allowable configuration of an object, state is not changed when moving back and forth betweeen the python object version and the seralized version. + +### The What + +- We test concrete classes in the codebase and do not test abstract classes as they are implementation details. [reference](https://enterprisecraftsmanship.com/posts/how-to-unit-test-an-abstract-class/) diff --git a/tests/unit/contracts/graph/test_nodes_parsed.py b/tests/unit/contracts/graph/test_nodes_parsed.py index 4e8392cb3df..bce860e63d3 100644 --- a/tests/unit/contracts/graph/test_nodes_parsed.py +++ b/tests/unit/contracts/graph/test_nodes_parsed.py @@ -1333,7 +1333,7 @@ def test_invalid_missing_updated_at(basic_timestamp_snapshot_config_dict): bad_fields = basic_timestamp_snapshot_config_dict del bad_fields["updated_at"] bad_fields["check_cols"] = "all" - assert_fails_validation(bad_fields, SnapshotConfig) + assert_snapshot_config_fails_validation(bad_fields) @pytest.fixture @@ -1437,7 +1437,7 @@ def test_complex_snapshot_config( def test_invalid_check_wrong_strategy(basic_check_snapshot_config_dict): wrong_strategy = basic_check_snapshot_config_dict wrong_strategy["strategy"] = "timestamp" - assert_fails_validation(wrong_strategy, SnapshotConfig) + assert_snapshot_config_fails_validation(wrong_strategy) def test_invalid_missing_check_cols(basic_check_snapshot_config_dict): @@ -1445,6 +1445,8 @@ def test_invalid_missing_check_cols(basic_check_snapshot_config_dict): del wrong_fields["check_cols"] with pytest.raises(ValidationError, match=r"A snapshot configured with the check strategy"): SnapshotConfig.validate(wrong_fields) + cfg = SnapshotConfig.from_dict(wrong_fields) + cfg.final_validate() def test_missing_snapshot_configs(basic_check_snapshot_config_dict): @@ -1452,22 +1454,35 @@ def test_missing_snapshot_configs(basic_check_snapshot_config_dict): del wrong_fields["strategy"] with pytest.raises(ValidationError, match=r"Snapshots must be configured with a 'strategy'"): SnapshotConfig.validate(wrong_fields) + cfg = SnapshotConfig.from_dict(wrong_fields) + cfg.final_validate() wrong_fields["strategy"] = "timestamp" del wrong_fields["unique_key"] with pytest.raises(ValidationError, match=r"Snapshots must be configured with a 'strategy'"): SnapshotConfig.validate(wrong_fields) + cfg = SnapshotConfig.from_dict(wrong_fields) + cfg.final_validate() wrong_fields["unique_key"] = "id" del wrong_fields["target_schema"] with pytest.raises(ValidationError, match=r"Snapshots must be configured with a 'strategy'"): SnapshotConfig.validate(wrong_fields) + cfg = SnapshotConfig.from_dict(wrong_fields) + cfg.final_validate() + + +def assert_snapshot_config_fails_validation(dct): + with pytest.raises(ValidationError): + SnapshotConfig.validate(dct) + obj = SnapshotConfig.from_dict(dct) + obj.final_validate() def test_invalid_check_value(basic_check_snapshot_config_dict): invalid_check_type = basic_check_snapshot_config_dict invalid_check_type["check_cols"] = "some" - assert_fails_validation(invalid_check_type, SnapshotConfig) + assert_snapshot_config_fails_validation(invalid_check_type) @pytest.fixture diff --git a/tests/unit/task/test_clone.py b/tests/unit/task/test_clone.py new file mode 100644 index 00000000000..9ef07d5e492 --- /dev/null +++ b/tests/unit/task/test_clone.py @@ -0,0 +1,16 @@ +from unittest.mock import MagicMock, patch + +from dbt.flags import get_flags +from dbt.task.clone import CloneTask + + +def test_clone_task_not_preserve_edges(): + mock_node_selector = MagicMock() + mock_spec = MagicMock() + with patch.object( + CloneTask, "get_node_selector", return_value=mock_node_selector + ), patch.object(CloneTask, "get_selection_spec", return_value=mock_spec): + task = CloneTask(get_flags(), None, None) + task.get_graph_queue() + # when we get the graph queue, preserve_edges is False + mock_node_selector.get_graph_queue.assert_called_with(mock_spec, False) diff --git a/tests/unit/task/test_run.py b/tests/unit/task/test_run.py new file mode 100644 index 00000000000..c689e8f41aa --- /dev/null +++ b/tests/unit/task/test_run.py @@ -0,0 +1,52 @@ +from argparse import Namespace +from unittest.mock import MagicMock, patch + +import pytest + +from dbt.config.runtime import RuntimeConfig +from dbt.flags import get_flags, set_from_args +from dbt.task.run import RunTask +from dbt.tests.util import safe_set_invocation_context + + +@pytest.mark.parametrize( + "exception_to_raise, expected_cancel_connections", + [ + (SystemExit, True), + (KeyboardInterrupt, True), + (Exception, False), + ], +) +def test_run_task_cancel_connections( + exception_to_raise, expected_cancel_connections, runtime_config: RuntimeConfig +): + safe_set_invocation_context() + + def mock_run_queue(*args, **kwargs): + raise exception_to_raise("Test exception") + + with patch.object(RunTask, "run_queue", mock_run_queue), patch.object( + RunTask, "_cancel_connections" + ) as mock_cancel_connections: + + set_from_args(Namespace(write_json=False), None) + task = RunTask( + get_flags(), + runtime_config, + None, + ) + with pytest.raises(exception_to_raise): + task.execute_nodes() + assert mock_cancel_connections.called == expected_cancel_connections + + +def test_run_task_preserve_edges(): + mock_node_selector = MagicMock() + mock_spec = MagicMock() + with patch.object(RunTask, "get_node_selector", return_value=mock_node_selector), patch.object( + RunTask, "get_selection_spec", return_value=mock_spec + ): + task = RunTask(get_flags(), None, None) + task.get_graph_queue() + # when we get the graph queue, preserve_edges is True + mock_node_selector.get_graph_queue.assert_called_with(mock_spec, True) diff --git a/tests/unit/task/test_runnable.py b/tests/unit/task/test_runnable.py deleted file mode 100644 index 17e09830892..00000000000 --- a/tests/unit/task/test_runnable.py +++ /dev/null @@ -1,151 +0,0 @@ -from dataclasses import dataclass -from typing import AbstractSet, Any, Dict, List, Optional, Tuple - -import networkx as nx -import pytest - -from dbt.artifacts.resources.types import NodeType -from dbt.graph import Graph, ResourceTypeSelector -from dbt.task.runnable import GraphRunnableMode, GraphRunnableTask -from dbt.tests.util import safe_set_invocation_context -from tests.unit.utils import MockNode, make_manifest - - -@dataclass -class MockArgs: - """Simple mock args for us in a runnable task""" - - state: Optional[Dict[str, Any]] = None - defer_state: Optional[Dict[str, Any]] = None - write_json: bool = False - selector: Optional[str] = None - select: Tuple[str] = () - exclude: Tuple[str] = () - - -@dataclass -class MockConfig: - """Simple mock config for use in a RunnableTask""" - - threads: int = 1 - target_name: str = "mock_config_target_name" - - def get_default_selector_name(self): - return None - - -class MockRunnableTask(GraphRunnableTask): - def __init__( - self, - exception_class: Exception = Exception, - nodes: Optional[List[MockNode]] = None, - edges: Optional[List[Tuple[str, str]]] = None, - ): - nodes = nodes or [] - edges = edges or [] - - self.forced_exception_class = exception_class - self.did_cancel: bool = False - super().__init__(args=MockArgs(), config=MockConfig(), manifest=None) - self.manifest = make_manifest(nodes=nodes) - digraph = nx.DiGraph() - for edge in edges: - digraph.add_edge(edge[0], edge[1]) - self.graph = Graph(digraph) - - def run_queue(self, pool): - """Override `run_queue` to raise a system exit""" - raise self.forced_exception_class() - - def _cancel_connections(self, pool): - """Override `_cancel_connections` to track whether it was called""" - self.did_cancel = True - - def get_node_selector(self): - """This is an `abstract_method` on `GraphRunnableTask`, thus we must implement it""" - selector = ResourceTypeSelector( - graph=self.graph, - manifest=self.manifest, - previous_state=self.previous_state, - resource_types=[NodeType.Model], - include_empty_nodes=True, - ) - return selector - - def defer_to_manifest(self, adapter, selected_uids: AbstractSet[str]): - """This is an `abstract_method` on `GraphRunnableTask`, thus we must implement it""" - return None - - -class MockRunnableTaskIndependent(MockRunnableTask): - def get_run_mode(self) -> GraphRunnableMode: - return GraphRunnableMode.Independent - - -def test_graph_runnable_task_cancels_connection_on_system_exit(): - - safe_set_invocation_context() - - task = MockRunnableTask(exception_class=SystemExit) - - with pytest.raises(SystemExit): - task.execute_nodes() - - # If `did_cancel` is True, that means `_cancel_connections` was called - assert task.did_cancel is True - - -def test_graph_runnable_task_cancels_connection_on_keyboard_interrupt(): - - safe_set_invocation_context() - - task = MockRunnableTask(exception_class=KeyboardInterrupt) - - with pytest.raises(KeyboardInterrupt): - task.execute_nodes() - - # If `did_cancel` is True, that means `_cancel_connections` was called - assert task.did_cancel is True - - -def test_graph_runnable_task_doesnt_cancel_connection_on_generic_exception(): - task = MockRunnableTask(exception_class=Exception) - - with pytest.raises(Exception): - task.execute_nodes() - - # If `did_cancel` is True, that means `_cancel_connections` was called - assert task.did_cancel is False - - -def test_graph_runnable_preserves_edges_by_default(): - task = MockRunnableTask( - nodes=[ - MockNode("test", "upstream_node", fqn="model.test.upstream_node"), - MockNode("test", "downstream_node", fqn="model.test.downstream_node"), - ], - edges=[("model.test.upstream_node", "model.test.downstream_node")], - ) - assert task.get_run_mode() == GraphRunnableMode.Topological - graph_queue = task.get_graph_queue() - - assert graph_queue.queued == {"model.test.upstream_node"} - assert graph_queue.inner.queue == [(0, "model.test.upstream_node")] - - -def test_graph_runnable_preserves_edges_false(): - task = MockRunnableTaskIndependent( - nodes=[ - MockNode("test", "upstream_node", fqn="model.test.upstream_node"), - MockNode("test", "downstream_node", fqn="model.test.downstream_node"), - ], - edges=[("model.test.upstream_node", "model.test.downstream_node")], - ) - assert task.get_run_mode() == GraphRunnableMode.Independent - graph_queue = task.get_graph_queue() - - assert graph_queue.queued == {"model.test.downstream_node", "model.test.upstream_node"} - assert graph_queue.inner.queue == [ - (0, "model.test.downstream_node"), - (0, "model.test.upstream_node"), - ]