diff --git a/CHANGELOG.md b/CHANGELOG.md
index 76fb4d9c6..34499cc47 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,8 +1,9 @@
-## dbt-databricks 1.7.x (TBD)
+## dbt-databricks 1.7.2 (TBD)
 
 ### Features
 
 - Adding capability to specify compute on a per model basis ([488](https://github.com/databricks/dbt-databricks/pull/488))
+- Selectively persist column docs that have changed between runs of incremental models ([513](https://github.com/databricks/dbt-databricks/pull/513))
 
 ## dbt-databricks 1.7.1 (Nov 13, 2023)
 
diff --git a/dbt/adapters/databricks/column.py b/dbt/adapters/databricks/column.py
index a3e088335..e717debd3 100644
--- a/dbt/adapters/databricks/column.py
+++ b/dbt/adapters/databricks/column.py
@@ -1,11 +1,13 @@
 from dataclasses import dataclass
-from typing import ClassVar, Dict
+from typing import ClassVar, Dict, Optional
 
 from dbt.adapters.spark.column import SparkColumn
 
 
 @dataclass
 class DatabricksColumn(SparkColumn):
+    comment: Optional[str] = None
+
     TYPE_LABELS: ClassVar[Dict[str, str]] = {
         "LONG": "BIGINT",
     }
diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py
index 2435f5e57..15ab317da 100644
--- a/dbt/adapters/databricks/impl.py
+++ b/dbt/adapters/databricks/impl.py
@@ -391,6 +391,7 @@ def parse_describe_extended(  # type: ignore[override]
                 column=column["col_name"],
                 column_index=idx,
                 dtype=column["data_type"],
+                comment=column["comment"],
             )
             for idx, column in enumerate(rows)
         ]
@@ -657,3 +658,24 @@ def _catalog(self, catalog: Optional[str]) -> Iterator[None]:
         finally:
             if current_catalog is not None:
                 self.execute_macro(USE_CATALOG_MACRO_NAME, kwargs=dict(catalog=current_catalog))
+
+    @available.parse(lambda *a, **k: {})
+    def get_persist_doc_columns(
+        self, existing_columns: List[DatabricksColumn], columns: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Returns a dictionary of columns that have updated comments."""
+        return_columns = {}
+
+        # Since existing_columns are gathered after writing the table, we don't need to include
+        # any columns from the model that are not in existing_columns. If we did, it would lead
+        # to an error when we tried to alter the table.
+        for column in existing_columns:
+            name = column.column
+            if (
+                name in columns
+                and "description" in columns[name]
+                and columns[name]["description"] != (column.comment or "")
+            ):
+                return_columns[name] = columns[name]
+
+        return return_columns
diff --git a/dbt/include/databricks/macros/adapters.sql b/dbt/include/databricks/macros/adapters.sql
index c396db626..4010a433b 100644
--- a/dbt/include/databricks/macros/adapters.sql
+++ b/dbt/include/databricks/macros/adapters.sql
@@ -505,3 +505,11 @@
   ) -%}
   {% do return([false, new_relation]) %}
 {% endmacro %}
+
+{% macro databricks__persist_docs(relation, model, for_relation, for_columns) -%}
+  {% if for_columns and config.persist_column_docs() and model.columns %}
+    {%- set existing_columns = adapter.get_columns_in_relation(relation) -%}
+    {%- set columns_to_persist_docs = adapter.get_persist_doc_columns(existing_columns, model.columns) -%}
+    {% do alter_column_comment(relation, columns_to_persist_docs) %}
+  {% endif %}
+{% endmacro %}
\ No newline at end of file
diff --git a/dbt/include/databricks/macros/materializations/incremental/incremental.sql b/dbt/include/databricks/macros/materializations/incremental/incremental.sql
index d0c63d174..af8777509 100644
--- a/dbt/include/databricks/macros/materializations/incremental/incremental.sql
+++ b/dbt/include/databricks/macros/materializations/incremental/incremental.sql
@@ -84,8 +84,6 @@
   {% do apply_grants(target_relation, grant_config, should_revoke) %}
 
   {% do persist_docs(target_relation, model) %}
-
-  {% do optimize(target_relation) %}
 
   {{ run_hooks(post_hooks) }}
 
diff --git a/tests/functional/adapter/incremental/fixtures.py b/tests/functional/adapter/incremental/fixtures.py
new file mode 100644
index 000000000..4afb5434f
--- /dev/null
+++ b/tests/functional/adapter/incremental/fixtures.py
@@ -0,0 +1,46 @@
+merge_update_columns_sql = """
+{{ config(
+    materialized = 'incremental',
+    unique_key = 'id',
+    merge_update_columns = ['msg'],
+) }}
+
+{% if not is_incremental() %}
+
+select cast(1 as bigint) as id, 'hello' as msg, 'blue' as color
+union all
+select cast(2 as bigint) as id, 'goodbye' as msg, 'red' as color
+
+{% else %}
+
+-- msg will be updated, color will be ignored
+select cast(2 as bigint) as id, 'yo' as msg, 'green' as color
+union all
+select cast(3 as bigint) as id, 'anyway' as msg, 'purple' as color
+
+{% endif %}
+"""
+
+no_comment_schema = """
+version: 2
+
+models:
+  - name: merge_update_columns_sql
+    columns:
+      - name: id
+      - name: msg
+      - name: color
+"""
+
+comment_schema = """
+version: 2
+
+models:
+  - name: merge_update_columns_sql
+    columns:
+      - name: id
+        description: This is the id column
+      - name: msg
+        description: This is the msg column
+      - name: color
+"""
diff --git a/tests/functional/adapter/incremental/test_incremental_persist_docs.py b/tests/functional/adapter/incremental/test_incremental_persist_docs.py
new file mode 100644
index 000000000..466938d33
--- /dev/null
+++ b/tests/functional/adapter/incremental/test_incremental_persist_docs.py
@@ -0,0 +1,34 @@
+from dbt.tests import util
+from tests.functional.adapter.incremental import fixtures
+
+import pytest
+
+
+class TestIncrementalPersistDocs:
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {
+            "merge_update_columns_sql.sql": fixtures.merge_update_columns_sql,
+            "schema.yml": fixtures.no_comment_schema,
+        }
+
+    @pytest.fixture(scope="class")
+    def project_config_update(self):
+        return {
+            "models": {
+                "test": {
+                    "+persist_docs": {
+                        "relation": True,
+                        "columns": True,
+                    },
+                }
+            }
+        }
+
+    def test_adding_comments(self, project):
+        util.run_dbt(["run"])
+        util.write_file(fixtures.comment_schema, "models", "schema.yml")
+        _, out = util.run_dbt_and_capture(["--debug", "run"])
+        assert "comment 'This is the id column'" in out
+        assert "comment 'This is the msg column'" in out
+        assert "comment ''" not in out
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index cb1f11545..5cef63a2c 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -9,6 +9,7 @@
 
 from dbt.adapters.databricks import __version__
 from dbt.adapters.databricks import DatabricksAdapter, DatabricksRelation
+from dbt.adapters.databricks.column import DatabricksColumn
 from dbt.adapters.databricks.impl import check_not_found_error
 from dbt.adapters.databricks.impl import get_identifier_list_string
 from dbt.adapters.databricks.connections import (
@@ -18,8 +19,10 @@
 )
 from tests.unit.utils import config_from_parts_or_dicts
 
+import pytest
 
-class TestDatabricksAdapter(unittest.TestCase):
+
+class DatabricksAdapterBase:
     def setUp(self):
         flags.STRICT_MODE = False
 
@@ -64,6 +67,8 @@ def _get_config(
 
         return config_from_parts_or_dicts(self.project_cfg, self.profile_cfg)
 
+
+class TestDatabricksAdapter(DatabricksAdapterBase, unittest.TestCase):
     def test_two_catalog_settings(self):
         with self.assertRaisesRegex(
             dbt.exceptions.DbtProfileError,
@@ -362,35 +367,29 @@ def test_parse_relation(self):
 
         # Mimics the output of Spark with a DESCRIBE TABLE EXTENDED
         plain_rows = [
-            ("col1", "decimal(22,0)"),
-            (
-                "col2",
-                "string",
-            ),
-            ("dt", "date"),
-            ("struct_col", "struct"),
-            ("# Partition Information", "data_type"),
-            ("# col_name", "data_type"),
-            ("dt", "date"),
-            (None, None),
+            ("col1", "decimal(22,0)", "comment"),
+            ("col2", "string", "comment"),
+            ("dt", "date", None),
+            ("struct_col", "struct", None),
+            ("# Partition Information", "data_type", None),
+            ("# col_name", "data_type", "comment"),
+            ("dt", "date", None),
+            (None, None, None),
             ("# Detailed Table Information", None),
             ("Database", None),
-            ("Owner", "root"),
-            ("Created Time", "Wed Feb 04 18:15:00 UTC 1815"),
-            ("Last Access", "Wed May 20 19:25:00 UTC 1925"),
-            ("Type", "MANAGED"),
-            ("Provider", "delta"),
-            ("Location", "/mnt/vo"),
-            ("Serde Library", "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"),
-            ("InputFormat", "org.apache.hadoop.mapred.SequenceFileInputFormat"),
-            (
-                "OutputFormat",
-                "org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat",
-            ),
-            ("Partition Provider", "Catalog"),
+            ("Owner", "root", None),
+            ("Created Time", "Wed Feb 04 18:15:00 UTC 1815", None),
+            ("Last Access", "Wed May 20 19:25:00 UTC 1925", None),
+            ("Type", "MANAGED", None),
+            ("Provider", "delta", None),
+            ("Location", "/mnt/vo", None),
+            ("Serde Library", "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe", None),
+            ("InputFormat", "org.apache.hadoop.mapred.SequenceFileInputFormat", None),
+            ("OutputFormat", "org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat", None),
+            ("Partition Provider", "Catalog", None),
         ]
 
-        input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows]
+        input_cols = [Row(keys=["col_name", "data_type", "comment"], values=r) for r in plain_rows]
 
         config = self._get_config()
         metadata, rows = DatabricksAdapter(config).parse_describe_extended(relation, input_cols)
@@ -431,6 +430,7 @@
                 "numeric_scale": None,
                 "numeric_precision": None,
                 "char_size": None,
+                "comment": "comment",
             },
         )
 
@@ -448,6 +448,7 @@
                 "numeric_scale": None,
"numeric_precision": None, "char_size": None, + "comment": "comment", }, ) @@ -465,6 +466,7 @@ def test_parse_relation(self): "numeric_scale": None, "numeric_precision": None, "char_size": None, + "comment": None, }, ) @@ -482,6 +484,7 @@ def test_parse_relation(self): "numeric_scale": None, "numeric_precision": None, "char_size": None, + "comment": None, }, ) @@ -496,12 +499,12 @@ def test_parse_relation_with_integer_owner(self): # Mimics the output of Spark with a DESCRIBE TABLE EXTENDED plain_rows = [ - ("col1", "decimal(22,0)"), - ("# Detailed Table Information", None), - ("Owner", 1234), + ("col1", "decimal(22,0)", "comment"), + ("# Detailed Table Information", None, None), + ("Owner", 1234, None), ] - input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows] + input_cols = [Row(keys=["col_name", "data_type", "comment"], values=r) for r in plain_rows] config = self._get_config() _, rows = DatabricksAdapter(config).parse_describe_extended(relation, input_cols) @@ -519,28 +522,25 @@ def test_parse_relation_with_statistics(self): # Mimics the output of Spark with a DESCRIBE TABLE EXTENDED plain_rows = [ - ("col1", "decimal(22,0)"), - ("# Partition Information", "data_type"), - (None, None), - ("# Detailed Table Information", None), - ("Database", None), - ("Owner", "root"), - ("Created Time", "Wed Feb 04 18:15:00 UTC 1815"), - ("Last Access", "Wed May 20 19:25:00 UTC 1925"), - ("Statistics", "1109049927 bytes, 14093476 rows"), - ("Type", "MANAGED"), - ("Provider", "delta"), - ("Location", "/mnt/vo"), - ("Serde Library", "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"), - ("InputFormat", "org.apache.hadoop.mapred.SequenceFileInputFormat"), - ( - "OutputFormat", - "org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat", - ), - ("Partition Provider", "Catalog"), + ("col1", "decimal(22,0)", "comment"), + ("# Partition Information", "data_type", None), + (None, None, None), + ("# Detailed Table Information", None, None), + ("Database", None, None), + ("Owner", "root", None), + ("Created Time", "Wed Feb 04 18:15:00 UTC 1815", None), + ("Last Access", "Wed May 20 19:25:00 UTC 1925", None), + ("Statistics", "1109049927 bytes, 14093476 rows", None), + ("Type", "MANAGED", None), + ("Provider", "delta", None), + ("Location", "/mnt/vo", None), + ("Serde Library", "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe", None), + ("InputFormat", "org.apache.hadoop.mapred.SequenceFileInputFormat", None), + ("OutputFormat", "org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat", None), + ("Partition Provider", "Catalog", None), ] - input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows] + input_cols = [Row(keys=["col_name", "data_type", "comment"], values=r) for r in plain_rows] config = self._get_config() metadata, rows = DatabricksAdapter(config).parse_describe_extended(relation, input_cols) @@ -576,6 +576,7 @@ def test_parse_relation_with_statistics(self): "table_owner": "root", "column": "col1", "column_index": 0, + "comment": "comment", "dtype": "decimal(22,0)", "numeric_scale": None, "numeric_precision": None, @@ -652,6 +653,7 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self) "stats:bytes:include": True, "stats:bytes:label": "bytes", "stats:bytes:value": 123456789, + "comment": None, }, ) @@ -666,6 +668,7 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self) "column": "struct_col", "column_index": 3, "dtype": "struct", + "comment": None, "numeric_scale": None, 
"numeric_precision": None, "char_size": None, @@ -729,6 +732,7 @@ def test_parse_columns_from_information_with_view_type(self): "table_owner": "root", "column": "col2", "column_index": 1, + "comment": None, "dtype": "string", "numeric_scale": None, "numeric_precision": None, @@ -746,6 +750,7 @@ def test_parse_columns_from_information_with_view_type(self): "table_owner": "root", "column": "struct_col", "column_index": 3, + "comment": None, "dtype": "struct", "numeric_scale": None, "numeric_precision": None, @@ -795,6 +800,7 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel "table_owner": "root", "column": "dt", "column_index": 2, + "comment": None, "dtype": "date", "numeric_scale": None, "numeric_precision": None, @@ -820,6 +826,7 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel "table_owner": "root", "column": "struct_col", "column_index": 3, + "comment": None, "dtype": "struct", "numeric_scale": None, "numeric_precision": None, @@ -916,3 +923,46 @@ def test_unexpected_error(self): self.assertFalse(check_not_found_error("[DATABASE_NOT_FOUND]")) self.assertFalse(check_not_found_error("Schema foo not found")) self.assertFalse(check_not_found_error("Database 'foo' not there")) + + +class TestGetPersistDocColumns(DatabricksAdapterBase): + @pytest.fixture(scope="class") + def adapter(self) -> DatabricksAdapter: + self.setUp() + return DatabricksAdapter(self._get_config()) + + def create_column(self, name, comment) -> DatabricksColumn: + return DatabricksColumn( + column=name, + dtype="string", + comment=comment, + ) + + def test_get_persist_doc_columns_empty(self, adapter): + assert adapter.get_persist_doc_columns([], {}) == {} + + def test_get_persist_doc_columns_no_match(self, adapter): + existing = [self.create_column("col1", "comment1")] + column_dict = {"col2": {"name": "col2", "description": "comment2"}} + assert adapter.get_persist_doc_columns(existing, column_dict) == {} + + def test_get_persist_doc_columns_full_match(self, adapter): + existing = [self.create_column("col1", "comment1")] + column_dict = {"col1": {"name": "col1", "description": "comment1"}} + assert adapter.get_persist_doc_columns(existing, column_dict) == {} + + def test_get_persist_doc_columns_partial_match(self, adapter): + existing = [self.create_column("col1", "comment1")] + column_dict = {"col1": {"name": "col1", "description": "comment2"}} + assert adapter.get_persist_doc_columns(existing, column_dict) == column_dict + + def test_get_persist_doc_columns_mixed(self, adapter): + existing = [self.create_column("col1", "comment1"), self.create_column("col2", "comment2")] + column_dict = { + "col1": {"name": "col1", "description": "comment2"}, + "col2": {"name": "col2", "description": "comment2"}, + } + expected = { + "col1": {"name": "col1", "description": "comment2"}, + } + assert adapter.get_persist_doc_columns(existing, column_dict) == expected