Skip to content

Commit

Permalink
case-insensitive comparisons in unit testing, base unit testing test (#…
Browse files Browse the repository at this point in the history
…55)

Co-authored-by: Mike Alfare <[email protected]>
  • Loading branch information
MichelleArk and mikealfare authored Feb 9, 2024
1 parent cb08aae commit c30497e
Show file tree
Hide file tree
Showing 8 changed files with 244 additions and 26 deletions.
2 changes: 1 addition & 1 deletion dbt/include/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from pkgutil import extend_path

# Extend this package's module search path via pkgutil so that other
# distributions can contribute modules to the same package name.
# A single call is sufficient; repeating the identical assignment is redundant.
__path__ = extend_path(__path__, __name__)
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@

{% macro default__get_unit_test_sql(main_sql, expected_fixture_sql, expected_column_names) -%}
-- Build actual result given inputs
with dbt_internal_unit_test_actual AS (
with dbt_internal_unit_test_actual as (
select
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%},{% endif %}{%- endfor -%}, {{ dbt.string_literal("actual") }} as actual_or_expected
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%},{% endif %}{%- endfor -%}, {{ dbt.string_literal("actual") }} as {{ adapter.quote("actual_or_expected") }}
from (
{{ main_sql }}
) _dbt_internal_unit_test_actual
),
-- Build expected result
dbt_internal_unit_test_expected AS (
dbt_internal_unit_test_expected as (
select
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%}, {% endif %}{%- endfor -%}, {{ dbt.string_literal("expected") }} as actual_or_expected
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%}, {% endif %}{%- endfor -%}, {{ dbt.string_literal("expected") }} as {{ adapter.quote("actual_or_expected") }}
from (
{{ expected_fixture_sql }}
) _dbt_internal_unit_test_expected
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
{%- set columns_in_relation = adapter.get_columns_in_relation(temp_relation) -%}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
{%- do column_name_to_data_types.update({column.name|lower: column.data_type}) -%}
{%- endfor -%}

{% set unit_test_sql = get_unit_test_sql(sql, get_expected_sql(expected_rows, column_name_to_data_types), tested_expected_column_names) %}
Expand Down
56 changes: 36 additions & 20 deletions dbt/include/global_project/macros/unit_test_sql/get_fixture_sql.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
{% set default_row = {} %}

{%- if not column_name_to_data_types -%}
{%- set columns_in_relation = adapter.get_columns_in_relation(defer_relation or this) -%}
{%- set columns_in_relation = adapter.get_columns_in_relation(this) -%}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
{#-- This needs to be a case-insensitive comparison --#}
{%- do column_name_to_data_types.update({column.name|lower: column.data_type}) -%}
{%- endfor -%}
{%- endif -%}

Expand All @@ -18,12 +19,13 @@
{%- do default_row.update({column_name: (safe_cast("null", column_type) | trim )}) -%}
{%- endfor -%}


{%- for row in rows -%}
{%- do format_row(row, column_name_to_data_types) -%}
{%- set formatted_row = format_row(row, column_name_to_data_types) -%}
{%- set default_row_copy = default_row.copy() -%}
{%- do default_row_copy.update(row) -%}
{%- do default_row_copy.update(formatted_row) -%}
select
{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} as {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- endfor %}
{%- if not loop.last %}
union all
Expand All @@ -32,7 +34,7 @@ union all

{%- if (rows | length) == 0 -%}
select
{%- for column_name, column_value in default_row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%},{%- endif %}
{%- for column_name, column_value in default_row.items() %} {{ column_value }} as {{ column_name }}{% if not loop.last -%},{%- endif %}
{%- endfor %}
limit 0
{%- endif -%}
Expand All @@ -46,9 +48,9 @@ union all
limit 0
{%- else -%}
{%- for row in rows -%}
{%- do format_row(row, column_name_to_data_types) -%}
{%- set formatted_row = format_row(row, column_name_to_data_types) -%}
select
{%- for column_name, column_value in row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- for column_name, column_value in formatted_row.items() %} {{ column_value }} as {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- endfor %}
{%- if not loop.last %}
union all
Expand All @@ -59,18 +61,32 @@ union all
{% endmacro %}

{%- macro format_row(row, column_name_to_data_types) -%}
    {#-- Build and return a NEW dict that maps each lower-cased column name in `row`
         to a SQL expression casting the fixture value to that column's type.
         `row` itself is never modified.
         Raises a compiler error when `row` names a column that is not a key of
         `column_name_to_data_types` (the callers visible in this change build that
         mapping with lower-cased column names). --#}

    {#-- generate case-insensitive formatted row --#}
    {% set formatted_row = {} %}
    {%- for column_name, column_value in row.items() -%}
        {% set column_name = column_name|lower %}

        {%- if column_name not in column_name_to_data_types %}
            {#-- if user-provided row contains column name that relation does not contain, raise an error --#}
            {% set fixture_name = "expected output" if model.resource_type == 'unit_test' else ("'" ~ model.name ~ "'") %}
            {{ exceptions.raise_compiler_error(
                "Invalid column name: '" ~ column_name ~ "' in unit test fixture for " ~ fixture_name ~ "."
                "\nAccepted columns for " ~ fixture_name ~ " are: " ~ (column_name_to_data_types.keys()|list)
            ) }}
        {%- endif -%}

        {%- set column_type = column_name_to_data_types[column_name] %}

        {#-- sanitize column_value: wrap yaml strings in quotes, apply cast --#}
        {%- set column_value_clean = column_value -%}
        {%- if column_value is string -%}
            {#-- escape embedded single quotes before wrapping the value in a string literal --#}
            {%- set column_value_clean = dbt.string_literal(dbt.escape_single_quotes(column_value)) -%}
        {%- elif column_value is none -%}
            {%- set column_value_clean = 'null' -%}
        {%- endif -%}

        {%- set row_update = {column_name: safe_cast(column_value_clean, column_type) } -%}
        {%- do formatted_row.update(row_update) -%}
    {%- endfor -%}
    {{ return(formatted_row) }}
{%- endmacro -%}
7 changes: 7 additions & 0 deletions dbt/include/global_project/macros/utils/cast.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{#-- Cast `field` to `type` by delegating to the `cast` implementation that
     adapter.dispatch resolves in the 'dbt' namespace. --#}
{% macro cast(field, type) %}
    {%- set cast_impl = adapter.dispatch('cast', 'dbt') -%}
    {{ return(cast_impl(field, type)) }}
{% endmacro %}

{#-- Default `cast` implementation: renders `cast(<field> as <type>)`. --#}
{% macro default__cast(field, type) %}
cast({{ field }} as {{ type }})
{% endmacro %}
49 changes: 49 additions & 0 deletions dbt/tests/adapter/unit_testing/test_case_insensitivity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
from dbt.tests.util import run_dbt


# Model under test: selects the single column exposed by the upstream model.
my_model_sql = """
select
tested_column from {{ ref('my_upstream_model')}}
"""

# Upstream model whose rows the unit test replaces with fixture rows.
my_upstream_model_sql = """
select 1 as tested_column
"""

# Unit test definition whose input and expected rows spell the same column
# name in lower, upper and mixed case. The nesting matters: `model`, `given`
# and `expect` belong to the unit test entry, and each `rows` list belongs to
# its own `input` / `expect` mapping.
test_my_model_yml = """
unit_tests:
  - name: test_my_model
    model: my_model
    given:
      - input: ref('my_upstream_model')
        rows:
          - {tested_column: 1}
          - {TESTED_COLUMN: 2}
          - {tested_colUmn: 3}
    expect:
      rows:
        - {tested_column: 1}
        - {TESTED_COLUMN: 2}
        - {tested_colUmn: 3}
"""


class BaseUnitTestCaseInsensivity:
    """Reusable test case: unit test fixtures may spell column names in any case.

    The project consists of two models and one unit test whose input and
    expected rows use lower-, upper- and mixed-case spellings of the same
    column name.
    """

    @pytest.fixture(scope="class")
    def models(self):
        """Return the project files for this test, keyed by file name."""
        return {
            "my_model.sql": my_model_sql,
            "my_upstream_model.sql": my_upstream_model_sql,
            "unit_tests.yml": test_my_model_yml,
        }

    def test_case_insensitivity(self, project):
        """Build both models, then run the tests against the mixed-case fixtures."""
        results = run_dbt(["run"])
        assert len(results) == 2

        # The return value is not inspected, so it is not bound to a name.
        # NOTE(review): this relies on run_dbt itself failing the test when the
        # invocation does not succeed (other call sites pass expect_pass=False
        # to opt out) -- confirm against dbt.tests.util.
        run_dbt(["test"])


class TestPosgresUnitTestCaseInsensitivity(BaseUnitTestCaseInsensivity):
    """Concrete, collectable variant of the case-insensitivity suite; adds nothing to the base."""
62 changes: 62 additions & 0 deletions dbt/tests/adapter/unit_testing/test_invalid_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest
from dbt.tests.util import run_dbt, run_dbt_and_capture


# Model under test: selects the single column exposed by the upstream model.
my_model_sql = """
select
tested_column from {{ ref('my_upstream_model')}}
"""

# Upstream model whose rows the unit tests replace with fixture rows.
my_upstream_model_sql = """
select 1 as tested_column
"""

# Two unit tests, each deliberately referencing a column that does not exist:
# the first in an input fixture, the second in the expected output. The
# nesting matters: each `- name:` entry is its own unit test carrying its own
# `model`, `given` and `expect` keys.
test_my_model_yml = """
unit_tests:
  - name: test_invalid_input_column_name
    model: my_model
    given:
      - input: ref('my_upstream_model')
        rows:
          - {invalid_column_name: 1}
    expect:
      rows:
        - {tested_column: 1}
  - name: test_invalid_expect_column_name
    model: my_model
    given:
      - input: ref('my_upstream_model')
        rows:
          - {tested_column: 1}
    expect:
      rows:
        - {invalid_column_name: 1}
"""


class BaseUnitTestInvalidInput:
    """Reusable test case: unknown fixture column names produce a clear error.

    One unit test uses an unknown column in an input fixture, the other in
    the expected output; each must fail with a message naming the column and
    the fixture it was found in.
    """

    @pytest.fixture(scope="class")
    def models(self):
        """Return the project files for this test, keyed by file name."""
        project_files = [
            ("my_model.sql", my_model_sql),
            ("my_upstream_model.sql", my_upstream_model_sql),
            ("unit_tests.yml", test_my_model_yml),
        ]
        return dict(project_files)

    def test_invalid_input(self, project):
        """Build the models, then check the error reported by each failing unit test."""
        run_results = run_dbt(["run"])
        assert len(run_results) == 2

        # (selector, message that must appear in the captured output), checked in order.
        expectations = (
            (
                "test_name:test_invalid_input_column_name",
                "Invalid column name: 'invalid_column_name' in unit test fixture for 'my_upstream_model'.",
            ),
            (
                "test_name:test_invalid_expect_column_name",
                "Invalid column name: 'invalid_column_name' in unit test fixture for expected output.",
            ),
        )
        for selector, expected_message in expectations:
            _, captured = run_dbt_and_capture(
                ["test", "--select", selector], expect_pass=False
            )
            assert expected_message in captured


class TestPostgresUnitTestInvalidInput(BaseUnitTestInvalidInput):
    """Concrete, collectable variant of the invalid-input suite; adds nothing to the base."""
84 changes: 84 additions & 0 deletions dbt/tests/adapter/unit_testing/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import pytest

from dbt.tests.util import write_file, run_dbt

# Model under test: selects the single column exposed by the upstream model.
# This string is written as-is (it is never passed through str.format).
my_model_sql = """
select
tested_column from {{ ref('my_upstream_model')}}
"""

# str.format template: `{sql_value}` is replaced by the SQL expression whose
# type is being exercised.
my_upstream_model_sql = """
select
{sql_value} as tested_column
"""

# str.format template: `{yaml_value}` is replaced by the fixture value; the
# doubled braces render as the literal braces of a YAML flow mapping. The
# nesting matters: `model`, `given` and `expect` belong to the unit test
# entry, and each `rows` list belongs to its own `input` / `expect` mapping.
test_my_model_yml = """
unit_tests:
  - name: test_my_model
    model: my_model
    given:
      - input: ref('my_upstream_model')
        rows:
          - {{ tested_column: {yaml_value} }}
    expect:
      rows:
        - {{ tested_column: {yaml_value} }}
"""


class BaseUnitTestingTypes:
    """Reusable test case: unit test fixtures work for a range of column types.

    For every (sql_value, yaml_value) pair the upstream model is rewritten to
    select `sql_value`, the unit test YAML is rewritten to use `yaml_value` as
    both input and expected value, and the unit test for `my_model` is run.
    """

    @pytest.fixture
    def data_types(self):
        """Return the [sql_value, yaml_value] pairs to exercise."""
        # sql_value, yaml_value
        return [
            ["1", "1"],
            ["'1'", "1"],
            ["true", "true"],
            ["DATE '2020-01-02'", "2020-01-02"],
            ["TIMESTAMP '2013-11-03 00:00:00-0'", "2013-11-03 00:00:00-0"],
            ["TIMESTAMPTZ '2013-11-03 00:00:00-0'", "2013-11-03 00:00:00-0"],
            ["'1'::numeric", "1"],
            [
                """'{"bar": "baz", "balance": 7.77, "active": false}'::json""",
                """'{"bar": "baz", "balance": 7.77, "active": false}'""",
            ],
            # TODO: support complex types
            # ["ARRAY['a','b','c']", """'{"a", "b", "c"}'"""],
            # ["ARRAY[1,2,3]", """'{1, 2, 3}'"""],
        ]

    @pytest.fixture(scope="class")
    def models(self):
        """Return the initial project files (still unformatted templates), keyed by file name."""
        return {
            "my_model.sql": my_model_sql,
            "my_upstream_model.sql": my_upstream_model_sql,
            "schema.yml": test_my_model_yml,
        }

    def test_unit_test_data_type(self, project, data_types):
        """Run the unit test once per data type pair.

        Raises:
            AssertionError: if running the tests raises for any pair; the
                message names the offending ``sql_value`` and the original
                exception is attached as the cause.
        """
        for sql_value, yaml_value in data_types:
            # Write parametrized type value to sql files
            write_file(
                my_upstream_model_sql.format(sql_value=sql_value),
                "models",
                "my_upstream_model.sql",
            )

            # Write parametrized type value to unit test yaml definition
            write_file(
                test_my_model_yml.format(yaml_value=yaml_value),
                "models",
                "schema.yml",
            )

            results = run_dbt(["run", "--select", "my_upstream_model"])
            assert len(results) == 1

            try:
                run_dbt(["test", "--select", "my_model"])
            except Exception as exc:
                # Chain explicitly so the underlying failure is reported as
                # the direct cause of this more descriptive error.
                raise AssertionError(
                    f"unit test failed when testing model with {sql_value}"
                ) from exc


class TestPostgresUnitTestingTypes(BaseUnitTestingTypes):
    """Concrete, collectable variant of the data-type suite; adds nothing to the base."""

0 comments on commit c30497e

Please sign in to comment.