diff --git a/.changes/unreleased/Fixes-20240607-134648.yaml b/.changes/unreleased/Fixes-20240607-134648.yaml new file mode 100644 index 00000000000..f40b98678f9 --- /dev/null +++ b/.changes/unreleased/Fixes-20240607-134648.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Fix snapshot config to work in yaml files +time: 2024-06-07T13:46:48.383215-04:00 +custom: + Author: gshank + Issue: "4000" diff --git a/core/dbt/artifacts/resources/v1/snapshot.py b/core/dbt/artifacts/resources/v1/snapshot.py index 6164d953184..dee235a19df 100644 --- a/core/dbt/artifacts/resources/v1/snapshot.py +++ b/core/dbt/artifacts/resources/v1/snapshot.py @@ -18,39 +18,35 @@ class SnapshotConfig(NodeConfig): # Not using Optional because of serialization issues with a Union of str and List[str] check_cols: Union[str, List[str], None] = None - @classmethod - def validate(cls, data): - super().validate(data) - # Note: currently you can't just set these keys in schema.yml because this validation - # will fail when parsing the snapshot node. - if not data.get("strategy") or not data.get("unique_key") or not data.get("target_schema"): + def final_validate(self): + if not self.strategy or not self.unique_key or not self.target_schema: raise ValidationError( "Snapshots must be configured with a 'strategy', 'unique_key', " "and 'target_schema'." ) - if data.get("strategy") == "check": - if not data.get("check_cols"): + if self.strategy == "check": + if not self.check_cols: raise ValidationError( "A snapshot configured with the check strategy must " "specify a check_cols configuration." ) - if isinstance(data["check_cols"], str) and data["check_cols"] != "all": + if isinstance(self.check_cols, str) and self.check_cols != "all": raise ValidationError( - f"Invalid value for 'check_cols': {data['check_cols']}. " + f"Invalid value for 'check_cols': {self.check_cols}. " "Expected 'all' or a list of strings." ) - elif data.get("strategy") == "timestamp": - if not data.get("updated_at"): + elif self.strategy == "timestamp": + if not self.updated_at: raise ValidationError( "A snapshot configured with the timestamp strategy " "must specify an updated_at configuration." ) - if data.get("check_cols"): + if self.check_cols: raise ValidationError("A 'timestamp' snapshot should not have 'check_cols'") # If the strategy is not 'check' or 'timestamp' it's a custom strategy, # formerly supported with GenericSnapshotConfig - if data.get("materialized") and data.get("materialized") != "snapshot": + if self.materialized and self.materialized != "snapshot": raise ValidationError("A snapshot must have a materialized value of 'snapshot'") # Called by "calculate_node_config_dict" in ContextConfigGenerator diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index 1ac2ca6acf2..74e8f226ab2 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -467,6 +467,7 @@ def load(self) -> Manifest: self.process_model_inferred_primary_keys() self.check_valid_group_config() self.check_valid_access_property() + self.check_valid_snapshot_config() semantic_manifest = SemanticManifest(self.manifest) if not semantic_manifest.validate(): @@ -1345,6 +1346,16 @@ def check_valid_access_property(self): materialization=node.get_materialization(), ) + def check_valid_snapshot_config(self): + # Snapshot config can be set in either SQL files or yaml files, + # so we need to validate afterward. + for node in self.manifest.nodes.values(): + if node.resource_type != NodeType.Snapshot: + continue + if node.created_at < self.started_at: + continue + node.config.final_validate() + def write_perf_info(self, target_path: str): path = os.path.join(target_path, PERF_INFO_FILE_NAME) write_file(path, json.dumps(self._perf_info, cls=dbt.utils.JSONEncoder, indent=4)) diff --git a/tests/functional/configs/test_configs.py b/tests/functional/configs/test_configs.py index 8d520f1ff80..2bbfac85c5c 100644 --- a/tests/functional/configs/test_configs.py +++ b/tests/functional/configs/test_configs.py @@ -2,7 +2,6 @@ import pytest -from dbt.exceptions import ParsingError from dbt.tests.util import ( check_relations_equal, run_dbt, @@ -120,7 +119,7 @@ def test_snapshots_materialization_proj_config(self, project): snapshots_dir = os.path.join(project.project_root, "snapshots") write_file(simple_snapshot, snapshots_dir, "mysnapshot.sql") - with pytest.raises(ParsingError): + with pytest.raises(ValidationError): run_dbt() diff --git a/tests/functional/simple_snapshot/fixtures.py b/tests/functional/simple_snapshot/fixtures.py index 04e4905d4cb..a94f0c04875 100644 --- a/tests/functional/simple_snapshot/fixtures.py +++ b/tests/functional/simple_snapshot/fixtures.py @@ -86,7 +86,6 @@ models__schema_yml = """ -version: 2 snapshots: - name: snapshot_actual data_tests: @@ -97,7 +96,6 @@ """ models__schema_with_target_schema_yml = """ -version: 2 snapshots: - name: snapshot_actual data_tests: diff --git a/tests/functional/simple_snapshot/test_missing_strategy_snapshot.py b/tests/functional/simple_snapshot/test_missing_strategy_snapshot.py index 407cd15439f..46543da8f4b 100644 --- a/tests/functional/simple_snapshot/test_missing_strategy_snapshot.py +++ b/tests/functional/simple_snapshot/test_missing_strategy_snapshot.py @@ -1,7 +1,7 @@ import pytest -from dbt.exceptions import ParsingError from dbt.tests.util import run_dbt +from dbt_common.dataclass_schema import ValidationError from tests.functional.simple_snapshot.fixtures import ( macros__test_no_overlaps_sql, models__ref_snapshot_sql, @@ -10,7 +10,7 @@ snapshots_invalid__snapshot_sql = """ {# make sure to never name this anything with `target_schema` in the name, or the test will be invalid! #} -{% snapshot missing_field_target_underscore_schema %} +{% snapshot snapshot_actual %} {# missing the mandatory target_schema parameter #} {{ config( @@ -44,7 +44,10 @@ def macros(): def test_missing_strategy(project): - with pytest.raises(ParsingError) as exc: + with pytest.raises(ValidationError) as exc: run_dbt(["compile"], expect_pass=False) - assert "Snapshots must be configured with a 'strategy'" in str(exc.value) + assert ( + "Snapshots must be configured with a 'strategy', 'unique_key', and 'target_schema'" + in str(exc.value) + ) diff --git a/tests/functional/simple_snapshot/test_snapshot_config.py b/tests/functional/simple_snapshot/test_snapshot_config.py new file mode 100644 index 00000000000..5124cf9c38b --- /dev/null +++ b/tests/functional/simple_snapshot/test_snapshot_config.py @@ -0,0 +1,67 @@ +import pytest + +from dbt.tests.util import run_dbt, write_file + +orders_sql = """ +select 1 as id, 101 as user_id, 'pending' as status +""" + +snapshot_sql = """ +{% snapshot orders_snapshot %} + +{{ + config( + target_schema=schema, + strategy='check', + unique_key='id', + check_cols=['status'], + ) +}} + +select * from {{ ref('orders') }} + +{% endsnapshot %} +""" + +snapshot_no_config_sql = """ +{% snapshot orders_snapshot %} + +select * from {{ ref('orders') }} + +{% endsnapshot %} +""" + +snapshot_schema_yml = """ +snapshots: + - name: orders_snapshot + config: + target_schema: test + strategy: check + unique_key: id + check_cols: ['status'] +""" + + +class TestSnapshotConfig: + @pytest.fixture(scope="class") + def models(self): + return {"orders.sql": orders_sql} + + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot_orders.sql": snapshot_sql} + + def test_config(self, project): + run_dbt(["run"]) + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + # try to parse with config in schema file + write_file( + snapshot_no_config_sql, project.project_root, "snapshots", "snapshot_orders.sql" + ) + write_file(snapshot_schema_yml, project.project_root, "snapshots", "snapshot.yml") + results = run_dbt(["parse"]) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 diff --git a/tests/unit/contracts/graph/test_nodes_parsed.py b/tests/unit/contracts/graph/test_nodes_parsed.py index 4e8392cb3df..bce860e63d3 100644 --- a/tests/unit/contracts/graph/test_nodes_parsed.py +++ b/tests/unit/contracts/graph/test_nodes_parsed.py @@ -1333,7 +1333,7 @@ def test_invalid_missing_updated_at(basic_timestamp_snapshot_config_dict): bad_fields = basic_timestamp_snapshot_config_dict del bad_fields["updated_at"] bad_fields["check_cols"] = "all" - assert_fails_validation(bad_fields, SnapshotConfig) + assert_snapshot_config_fails_validation(bad_fields) @pytest.fixture @@ -1437,7 +1437,7 @@ def test_complex_snapshot_config( def test_invalid_check_wrong_strategy(basic_check_snapshot_config_dict): wrong_strategy = basic_check_snapshot_config_dict wrong_strategy["strategy"] = "timestamp" - assert_fails_validation(wrong_strategy, SnapshotConfig) + assert_snapshot_config_fails_validation(wrong_strategy) def test_invalid_missing_check_cols(basic_check_snapshot_config_dict): @@ -1445,6 +1445,8 @@ def test_invalid_missing_check_cols(basic_check_snapshot_config_dict): del wrong_fields["check_cols"] with pytest.raises(ValidationError, match=r"A snapshot configured with the check strategy"): SnapshotConfig.validate(wrong_fields) + cfg = SnapshotConfig.from_dict(wrong_fields) + cfg.final_validate() def test_missing_snapshot_configs(basic_check_snapshot_config_dict): @@ -1452,22 +1454,35 @@ def test_missing_snapshot_configs(basic_check_snapshot_config_dict): del wrong_fields["strategy"] with pytest.raises(ValidationError, match=r"Snapshots must be configured with a 'strategy'"): SnapshotConfig.validate(wrong_fields) + cfg = SnapshotConfig.from_dict(wrong_fields) + cfg.final_validate() wrong_fields["strategy"] = "timestamp" del wrong_fields["unique_key"] with pytest.raises(ValidationError, match=r"Snapshots must be configured with a 'strategy'"): SnapshotConfig.validate(wrong_fields) + cfg = SnapshotConfig.from_dict(wrong_fields) + cfg.final_validate() wrong_fields["unique_key"] = "id" del wrong_fields["target_schema"] with pytest.raises(ValidationError, match=r"Snapshots must be configured with a 'strategy'"): SnapshotConfig.validate(wrong_fields) + cfg = SnapshotConfig.from_dict(wrong_fields) + cfg.final_validate() + + +def assert_snapshot_config_fails_validation(dct): + with pytest.raises(ValidationError): + SnapshotConfig.validate(dct) + obj = SnapshotConfig.from_dict(dct) + obj.final_validate() def test_invalid_check_value(basic_check_snapshot_config_dict): invalid_check_type = basic_check_snapshot_config_dict invalid_check_type["check_cols"] = "some" - assert_fails_validation(invalid_check_type, SnapshotConfig) + assert_snapshot_config_fails_validation(invalid_check_type) @pytest.fixture