From c30497ebc15d79c27af68864c301ee1d9aea12be Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Thu, 8 Feb 2024 20:35:37 -0500 Subject: [PATCH] case-insensitive comparisons in unit testing, base unit testing test (#55) Co-authored-by: Mike Alfare <13974384+mikealfare@users.noreply.github.com> --- dbt/include/__init__.py | 2 +- .../macros/materializations/tests/helpers.sql | 8 +- .../macros/materializations/tests/unit.sql | 2 +- .../macros/unit_test_sql/get_fixture_sql.sql | 56 ++++++++----- .../global_project/macros/utils/cast.sql | 7 ++ .../unit_testing/test_case_insensitivity.py | 49 +++++++++++ .../unit_testing/test_invalid_input.py | 62 ++++++++++++++ dbt/tests/adapter/unit_testing/test_types.py | 84 +++++++++++++++++++ 8 files changed, 244 insertions(+), 26 deletions(-) create mode 100644 dbt/include/global_project/macros/utils/cast.sql create mode 100644 dbt/tests/adapter/unit_testing/test_case_insensitivity.py create mode 100644 dbt/tests/adapter/unit_testing/test_invalid_input.py create mode 100644 dbt/tests/adapter/unit_testing/test_types.py diff --git a/dbt/include/__init__.py b/dbt/include/__init__.py index 9088ea6a..b36383a6 100644 --- a/dbt/include/__init__.py +++ b/dbt/include/__init__.py @@ -1,3 +1,3 @@ from pkgutil import extend_path -__path__ = extend_path(__path__, __name__) \ No newline at end of file +__path__ = extend_path(__path__, __name__) diff --git a/dbt/include/global_project/macros/materializations/tests/helpers.sql b/dbt/include/global_project/macros/materializations/tests/helpers.sql index 13e640c2..ead727d9 100644 --- a/dbt/include/global_project/macros/materializations/tests/helpers.sql +++ b/dbt/include/global_project/macros/materializations/tests/helpers.sql @@ -22,17 +22,17 @@ {% macro default__get_unit_test_sql(main_sql, expected_fixture_sql, expected_column_names) -%} -- Build actual result given inputs -with dbt_internal_unit_test_actual AS ( +with dbt_internal_unit_test_actual as ( select - {% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%},{% endif %}{%- endfor -%}, {{ dbt.string_literal("actual") }} as actual_or_expected + {% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%},{% endif %}{%- endfor -%}, {{ dbt.string_literal("actual") }} as {{ adapter.quote("actual_or_expected") }} from ( {{ main_sql }} ) _dbt_internal_unit_test_actual ), -- Build expected result -dbt_internal_unit_test_expected AS ( +dbt_internal_unit_test_expected as ( select - {% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%}, {% endif %}{%- endfor -%}, {{ dbt.string_literal("expected") }} as actual_or_expected + {% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%}, {% endif %}{%- endfor -%}, {{ dbt.string_literal("expected") }} as {{ adapter.quote("actual_or_expected") }} from ( {{ expected_fixture_sql }} ) _dbt_internal_unit_test_expected diff --git a/dbt/include/global_project/macros/materializations/tests/unit.sql b/dbt/include/global_project/macros/materializations/tests/unit.sql index 79d5631b..6d7b632c 100644 --- a/dbt/include/global_project/macros/materializations/tests/unit.sql +++ b/dbt/include/global_project/macros/materializations/tests/unit.sql @@ -11,7 +11,7 @@ {%- set columns_in_relation = adapter.get_columns_in_relation(temp_relation) -%} {%- set column_name_to_data_types = {} -%} {%- for column in columns_in_relation -%} - {%- do column_name_to_data_types.update({column.name: column.dtype}) -%} + {%- do column_name_to_data_types.update({column.name|lower: column.data_type}) -%} {%- endfor -%} {% set unit_test_sql = get_unit_test_sql(sql, get_expected_sql(expected_rows, column_name_to_data_types), tested_expected_column_names) %} diff --git a/dbt/include/global_project/macros/unit_test_sql/get_fixture_sql.sql b/dbt/include/global_project/macros/unit_test_sql/get_fixture_sql.sql index 5c4c5005..ba35fab2 100644 --- a/dbt/include/global_project/macros/unit_test_sql/get_fixture_sql.sql +++ b/dbt/include/global_project/macros/unit_test_sql/get_fixture_sql.sql @@ -3,10 +3,11 @@ {% set default_row = {} %} {%- if not column_name_to_data_types -%} -{%- set columns_in_relation = adapter.get_columns_in_relation(defer_relation or this) -%} +{%- set columns_in_relation = adapter.get_columns_in_relation(this) -%} {%- set column_name_to_data_types = {} -%} {%- for column in columns_in_relation -%} -{%- do column_name_to_data_types.update({column.name: column.dtype}) -%} +{#-- This needs to be a case-insensitive comparison --#} +{%- do column_name_to_data_types.update({column.name|lower: column.data_type}) -%} {%- endfor -%} {%- endif -%} @@ -18,12 +19,13 @@ {%- do default_row.update({column_name: (safe_cast("null", column_type) | trim )}) -%} {%- endfor -%} + {%- for row in rows -%} -{%- do format_row(row, column_name_to_data_types) -%} +{%- set formatted_row = format_row(row, column_name_to_data_types) -%} {%- set default_row_copy = default_row.copy() -%} -{%- do default_row_copy.update(row) -%} +{%- do default_row_copy.update(formatted_row) -%} select -{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %} +{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} as {{ column_name }}{% if not loop.last -%}, {%- endif %} {%- endfor %} {%- if not loop.last %} union all @@ -32,7 +34,7 @@ union all {%- if (rows | length) == 0 -%} select - {%- for column_name, column_value in default_row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%},{%- endif %} + {%- for column_name, column_value in default_row.items() %} {{ column_value }} as {{ column_name }}{% if not loop.last -%},{%- endif %} {%- endfor %} limit 0 {%- endif -%} @@ -46,9 +48,9 @@ union all limit 0 {%- else -%} {%- for row in rows -%} -{%- do format_row(row, column_name_to_data_types) -%} +{%- set formatted_row = format_row(row, column_name_to_data_types) -%} select -{%- for column_name, column_value in row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %} +{%- for column_name, column_value in formatted_row.items() %} {{ column_value }} as {{ column_name }}{% if not loop.last -%}, {%- endif %} {%- endfor %} {%- if not loop.last %} union all @@ -59,18 +61,32 @@ union all {% endmacro %} {%- macro format_row(row, column_name_to_data_types) -%} + {#-- generate case-insensitive formatted row --#} + {% set formatted_row = {} %} + {%- for column_name, column_value in row.items() -%} + {% set column_name = column_name|lower %} -{#-- wrap yaml strings in quotes, apply cast --#} -{%- for column_name, column_value in row.items() -%} -{% set row_update = {column_name: column_value} %} -{%- if column_value is string -%} -{%- set row_update = {column_name: safe_cast(dbt.string_literal(column_value), column_name_to_data_types[column_name]) } -%} -{%- elif column_value is none -%} -{%- set row_update = {column_name: safe_cast('null', column_name_to_data_types[column_name]) } -%} -{%- else -%} -{%- set row_update = {column_name: safe_cast(column_value, column_name_to_data_types[column_name]) } -%} -{%- endif -%} -{%- do row.update(row_update) -%} -{%- endfor -%} + {%- if column_name not in column_name_to_data_types %} + {#-- if user-provided row contains column name that relation does not contain, raise an error --#} + {% set fixture_name = "expected output" if model.resource_type == 'unit_test' else ("'" ~ model.name ~ "'") %} + {{ exceptions.raise_compiler_error( + "Invalid column name: '" ~ column_name ~ "' in unit test fixture for " ~ fixture_name ~ "." + "\nAccepted columns for " ~ fixture_name ~ " are: " ~ (column_name_to_data_types.keys()|list) + ) }} + {%- endif -%} + + {%- set column_type = column_name_to_data_types[column_name] %} + + {#-- sanitize column_value: wrap yaml strings in quotes, apply cast --#} + {%- set column_value_clean = column_value -%} + {%- if column_value is string -%} + {%- set column_value_clean = dbt.string_literal(dbt.escape_single_quotes(column_value)) -%} + {%- elif column_value is none -%} + {%- set column_value_clean = 'null' -%} + {%- endif -%} + {%- set row_update = {column_name: safe_cast(column_value_clean, column_type) } -%} + {%- do formatted_row.update(row_update) -%} + {%- endfor -%} + {{ return(formatted_row) }} {%- endmacro -%} diff --git a/dbt/include/global_project/macros/utils/cast.sql b/dbt/include/global_project/macros/utils/cast.sql new file mode 100644 index 00000000..ea5b1aac --- /dev/null +++ b/dbt/include/global_project/macros/utils/cast.sql @@ -0,0 +1,7 @@ +{% macro cast(field, type) %} + {{ return(adapter.dispatch('cast', 'dbt') (field, type)) }} +{% endmacro %} + +{% macro default__cast(field, type) %} + cast({{field}} as {{type}}) +{% endmacro %} diff --git a/dbt/tests/adapter/unit_testing/test_case_insensitivity.py b/dbt/tests/adapter/unit_testing/test_case_insensitivity.py new file mode 100644 index 00000000..f6f89766 --- /dev/null +++ b/dbt/tests/adapter/unit_testing/test_case_insensitivity.py @@ -0,0 +1,49 @@ +import pytest +from dbt.tests.util import run_dbt + + +my_model_sql = """ +select + tested_column from {{ ref('my_upstream_model')}} +""" + +my_upstream_model_sql = """ +select 1 as tested_column +""" + +test_my_model_yml = """ +unit_tests: + - name: test_my_model + model: my_model + given: + - input: ref('my_upstream_model') + rows: + - {tested_column: 1} + - {TESTED_COLUMN: 2} + - {tested_colUmn: 3} + expect: + rows: + - {tested_column: 1} + - {TESTED_COLUMN: 2} + - {tested_colUmn: 3} +""" + + +class BaseUnitTestCaseInsensivity: + @pytest.fixture(scope="class") + def models(self): + return { + "my_model.sql": my_model_sql, + "my_upstream_model.sql": my_upstream_model_sql, + "unit_tests.yml": test_my_model_yml, + } + + def test_case_insensitivity(self, project): + results = run_dbt(["run"]) + assert len(results) == 2 + + results = run_dbt(["test"]) + + +class TestPosgresUnitTestCaseInsensitivity(BaseUnitTestCaseInsensivity): + pass diff --git a/dbt/tests/adapter/unit_testing/test_invalid_input.py b/dbt/tests/adapter/unit_testing/test_invalid_input.py new file mode 100644 index 00000000..6c41ceb9 --- /dev/null +++ b/dbt/tests/adapter/unit_testing/test_invalid_input.py @@ -0,0 +1,62 @@ +import pytest +from dbt.tests.util import run_dbt, run_dbt_and_capture + + +my_model_sql = """ +select + tested_column from {{ ref('my_upstream_model')}} +""" + +my_upstream_model_sql = """ +select 1 as tested_column +""" + +test_my_model_yml = """ +unit_tests: + - name: test_invalid_input_column_name + model: my_model + given: + - input: ref('my_upstream_model') + rows: + - {invalid_column_name: 1} + expect: + rows: + - {tested_column: 1} + - name: test_invalid_expect_column_name + model: my_model + given: + - input: ref('my_upstream_model') + rows: + - {tested_column: 1} + expect: + rows: + - {invalid_column_name: 1} +""" + + +class BaseUnitTestInvalidInput: + @pytest.fixture(scope="class") + def models(self): + return { + "my_model.sql": my_model_sql, + "my_upstream_model.sql": my_upstream_model_sql, + "unit_tests.yml": test_my_model_yml, + } + + def test_invalid_input(self, project): + results = run_dbt(["run"]) + assert len(results) == 2 + + _, out = run_dbt_and_capture( + ["test", "--select", "test_name:test_invalid_input_column_name"], expect_pass=False + ) + assert "Invalid column name: 'invalid_column_name' in unit test fixture for 'my_upstream_model'." in out + + _, out = run_dbt_and_capture( + ["test", "--select", "test_name:test_invalid_expect_column_name"], expect_pass=False + ) + assert "Invalid column name: 'invalid_column_name' in unit test fixture for expected output." in out + + +class TestPostgresUnitTestInvalidInput(BaseUnitTestInvalidInput): + pass diff --git a/dbt/tests/adapter/unit_testing/test_types.py b/dbt/tests/adapter/unit_testing/test_types.py new file mode 100644 index 00000000..1d19aafb --- /dev/null +++ b/dbt/tests/adapter/unit_testing/test_types.py @@ -0,0 +1,84 @@ +import pytest + +from dbt.tests.util import write_file, run_dbt + +my_model_sql = """ +select + tested_column from {{ ref('my_upstream_model')}} +""" + +my_upstream_model_sql = """ +select + {sql_value} as tested_column +""" + +test_my_model_yml = """ +unit_tests: + - name: test_my_model + model: my_model + given: + - input: ref('my_upstream_model') + rows: + - {{ tested_column: {yaml_value} }} + expect: + rows: + - {{ tested_column: {yaml_value} }} +""" + + +class BaseUnitTestingTypes: + @pytest.fixture + def data_types(self): + # sql_value, yaml_value + return [ + ["1", "1"], + ["'1'", "1"], + ["true", "true"], + ["DATE '2020-01-02'", "2020-01-02"], + ["TIMESTAMP '2013-11-03 00:00:00-0'", "2013-11-03 00:00:00-0"], + ["TIMESTAMPTZ '2013-11-03 00:00:00-0'", "2013-11-03 00:00:00-0"], + ["'1'::numeric", "1"], + [ + """'{"bar": "baz", "balance": 7.77, "active": false}'::json""", + """'{"bar": "baz", "balance": 7.77, "active": false}'""", + ], + # TODO: support complex types + # ["ARRAY['a','b','c']", """'{"a", "b", "c"}'"""], + # ["ARRAY[1,2,3]", """'{1, 2, 3}'"""], + ] + + @pytest.fixture(scope="class") + def models(self): + return { + "my_model.sql": my_model_sql, + "my_upstream_model.sql": my_upstream_model_sql, + "schema.yml": test_my_model_yml, + } + + def test_unit_test_data_type(self, project, data_types): + for sql_value, yaml_value in data_types: + # Write parametrized type value to sql files + write_file( + my_upstream_model_sql.format(sql_value=sql_value), + "models", + "my_upstream_model.sql", + ) + + # Write parametrized type value to unit test yaml definition + write_file( + test_my_model_yml.format(yaml_value=yaml_value), + "models", + "schema.yml", + ) + + results = run_dbt(["run", "--select", "my_upstream_model"]) + assert len(results) == 1 + + try: + run_dbt(["test", "--select", "my_model"]) + except Exception: + raise AssertionError(f"unit test failed when testing model with {sql_value}") + + +class TestPostgresUnitTestingTypes(BaseUnitTestingTypes): + pass