Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

case-insensitive comparisons in unit testing, base unit testing test #55

Merged
merged 14 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dbt/include/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from pkgutil import extend_path

__path__ = extend_path(__path__, __name__)
__path__ = extend_path(__path__, __name__)
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@

{% macro default__get_unit_test_sql(main_sql, expected_fixture_sql, expected_column_names) -%}
-- Build actual result given inputs
with dbt_internal_unit_test_actual AS (
with dbt_internal_unit_test_actual as (
select
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%},{% endif %}{%- endfor -%}, {{ dbt.string_literal("actual") }} as actual_or_expected
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%},{% endif %}{%- endfor -%}, {{ dbt.string_literal("actual") }} as {{ adapter.quote("actual_or_expected") }}
from (
{{ main_sql }}
) _dbt_internal_unit_test_actual
),
-- Build expected result
dbt_internal_unit_test_expected AS (
dbt_internal_unit_test_expected as (
select
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%}, {% endif %}{%- endfor -%}, {{ dbt.string_literal("expected") }} as actual_or_expected
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%}, {% endif %}{%- endfor -%}, {{ dbt.string_literal("expected") }} as {{ adapter.quote("actual_or_expected") }}
from (
{{ expected_fixture_sql }}
) _dbt_internal_unit_test_expected
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
{%- set columns_in_relation = adapter.get_columns_in_relation(temp_relation) -%}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
{%- do column_name_to_data_types.update({column.name|lower: column.data_type}) -%}
{%- endfor -%}

{% set unit_test_sql = get_unit_test_sql(sql, get_expected_sql(expected_rows, column_name_to_data_types), tested_expected_column_names) %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
{% set default_row = {} %}

{%- if not column_name_to_data_types -%}
{%- set columns_in_relation = adapter.get_columns_in_relation(defer_relation or this) -%}
{%- set columns_in_relation = adapter.get_columns_in_relation(this) -%}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
{#-- This needs to be a case-insensitive comparison --#}
{%- do column_name_to_data_types.update({column.name|lower: column.data_type}) -%}
{%- endfor -%}
{%- endif -%}

Expand All @@ -18,12 +19,13 @@
{%- do default_row.update({column_name: (safe_cast("null", column_type) | trim )}) -%}
{%- endfor -%}


{%- for row in rows -%}
{%- do format_row(row, column_name_to_data_types) -%}
{%- set formatted_row = format_row(row, column_name_to_data_types) -%}
{%- set default_row_copy = default_row.copy() -%}
{%- do default_row_copy.update(row) -%}
{%- do default_row_copy.update(formatted_row) -%}
select
{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} as {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- endfor %}
{%- if not loop.last %}
union all
Expand All @@ -32,7 +34,7 @@ union all

{%- if (rows | length) == 0 -%}
select
{%- for column_name, column_value in default_row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%},{%- endif %}
{%- for column_name, column_value in default_row.items() %} {{ column_value }} as {{ column_name }}{% if not loop.last -%},{%- endif %}
{%- endfor %}
limit 0
{%- endif -%}
Expand All @@ -46,9 +48,9 @@ union all
limit 0
{%- else -%}
{%- for row in rows -%}
{%- do format_row(row, column_name_to_data_types) -%}
{%- set formatted_row = format_row(row, column_name_to_data_types) -%}
select
{%- for column_name, column_value in row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- for column_name, column_value in formatted_row.items() %} {{ column_value }} as {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- endfor %}
{%- if not loop.last %}
union all
Expand All @@ -59,18 +61,32 @@ union all
{% endmacro %}

{%- macro format_row(row, column_name_to_data_types) -%}
{#-- generate case-insensitive formatted row --#}
{% set formatted_row = {} %}
{%- for column_name, column_value in row.items() -%}
{% set column_name = column_name|lower %}

{#-- wrap yaml strings in quotes, apply cast --#}
{%- for column_name, column_value in row.items() -%}
{% set row_update = {column_name: column_value} %}
{%- if column_value is string -%}
{%- set row_update = {column_name: safe_cast(dbt.string_literal(column_value), column_name_to_data_types[column_name]) } -%}
{%- elif column_value is none -%}
{%- set row_update = {column_name: safe_cast('null', column_name_to_data_types[column_name]) } -%}
{%- else -%}
{%- set row_update = {column_name: safe_cast(column_value, column_name_to_data_types[column_name]) } -%}
{%- endif -%}
{%- do row.update(row_update) -%}
{%- endfor -%}
{%- if column_name not in column_name_to_data_types %}
{#-- if user-provided row contains column name that relation does not contain, raise an error --#}
{% set fixture_name = "expected output" if model.resource_type == 'unit_test' else ("'" ~ model.name ~ "'") %}
{{ exceptions.raise_compiler_error(
"Invalid column name: '" ~ column_name ~ "' in unit test fixture for " ~ fixture_name ~ "."
"\nAccepted columns for " ~ fixture_name ~ " are: " ~ (column_name_to_data_types.keys()|list)
) }}
{%- endif -%}

{%- set column_type = column_name_to_data_types[column_name] %}

{#-- sanitize column_value: wrap yaml strings in quotes, apply cast --#}
{%- set column_value_clean = column_value -%}
{%- if column_value is string -%}
MichelleArk marked this conversation as resolved.
Show resolved Hide resolved
{%- set column_value_clean = dbt.string_literal(dbt.escape_single_quotes(column_value)) -%}
{%- elif column_value is none -%}
{%- set column_value_clean = 'null' -%}
{%- endif -%}

{%- set row_update = {column_name: safe_cast(column_value_clean, column_type) } -%}
{%- do formatted_row.update(row_update) -%}
{%- endfor -%}
{{ return(formatted_row) }}
{%- endmacro -%}
7 changes: 7 additions & 0 deletions dbt/include/global_project/macros/utils/cast.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{% macro cast(field, type) %}
mikealfare marked this conversation as resolved.
Show resolved Hide resolved
{{ return(adapter.dispatch('cast', 'dbt') (field, type)) }}
{% endmacro %}

{% macro default__cast(field, type) %}
cast({{field}} as {{type}})
{% endmacro %}
49 changes: 49 additions & 0 deletions dbt/tests/adapter/unit_testing/test_case_insensitivity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
from dbt.tests.util import run_dbt


my_model_sql = """
select
tested_column from {{ ref('my_upstream_model')}}
"""

my_upstream_model_sql = """
select 1 as tested_column
"""

test_my_model_yml = """
unit_tests:
- name: test_my_model
model: my_model
given:
- input: ref('my_upstream_model')
rows:
- {tested_column: 1}
- {TESTED_COLUMN: 2}
- {tested_colUmn: 3}
expect:
rows:
- {tested_column: 1}
- {TESTED_COLUMN: 2}
- {tested_colUmn: 3}
"""


class BaseUnitTestCaseInsensivity:
@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_model_sql,
"my_upstream_model.sql": my_upstream_model_sql,
"unit_tests.yml": test_my_model_yml,
}

def test_case_insensitivity(self, project):
results = run_dbt(["run"])
assert len(results) == 2

results = run_dbt(["test"])


class TestPosgresUnitTestCaseInsensitivity(BaseUnitTestCaseInsensivity):
MichelleArk marked this conversation as resolved.
Show resolved Hide resolved
pass
62 changes: 62 additions & 0 deletions dbt/tests/adapter/unit_testing/test_invalid_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest
from dbt.tests.util import run_dbt, run_dbt_and_capture


my_model_sql = """
select
tested_column from {{ ref('my_upstream_model')}}
"""

my_upstream_model_sql = """
select 1 as tested_column
"""

test_my_model_yml = """
unit_tests:
- name: test_invalid_input_column_name
model: my_model
given:
- input: ref('my_upstream_model')
rows:
- {invalid_column_name: 1}
expect:
rows:
- {tested_column: 1}
- name: test_invalid_expect_column_name
model: my_model
given:
- input: ref('my_upstream_model')
rows:
- {tested_column: 1}
expect:
rows:
- {invalid_column_name: 1}
"""


class BaseUnitTestInvalidInput:
@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_model_sql,
"my_upstream_model.sql": my_upstream_model_sql,
"unit_tests.yml": test_my_model_yml,
}

def test_invalid_input(self, project):
results = run_dbt(["run"])
assert len(results) == 2

_, out = run_dbt_and_capture(
["test", "--select", "test_name:test_invalid_input_column_name"], expect_pass=False
)
assert "Invalid column name: 'invalid_column_name' in unit test fixture for 'my_upstream_model'." in out

_, out = run_dbt_and_capture(
["test", "--select", "test_name:test_invalid_expect_column_name"], expect_pass=False
)
assert "Invalid column name: 'invalid_column_name' in unit test fixture for expected output." in out


class TestPostgresUnitTestInvalidInput(BaseUnitTestInvalidInput):
mikealfare marked this conversation as resolved.
Show resolved Hide resolved
pass
84 changes: 84 additions & 0 deletions dbt/tests/adapter/unit_testing/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import pytest

from dbt.tests.util import write_file, run_dbt

my_model_sql = """
select
tested_column from {{ ref('my_upstream_model')}}
"""

my_upstream_model_sql = """
select
{sql_value} as tested_column
"""

test_my_model_yml = """
unit_tests:
- name: test_my_model
model: my_model
given:
- input: ref('my_upstream_model')
rows:
- {{ tested_column: {yaml_value} }}
expect:
rows:
- {{ tested_column: {yaml_value} }}
"""


class BaseUnitTestingTypes:
@pytest.fixture
def data_types(self):
# sql_value, yaml_value
return [
["1", "1"],
["'1'", "1"],
["true", "true"],
["DATE '2020-01-02'", "2020-01-02"],
["TIMESTAMP '2013-11-03 00:00:00-0'", "2013-11-03 00:00:00-0"],
["TIMESTAMPTZ '2013-11-03 00:00:00-0'", "2013-11-03 00:00:00-0"],
["'1'::numeric", "1"],
[
"""'{"bar": "baz", "balance": 7.77, "active": false}'::json""",
"""'{"bar": "baz", "balance": 7.77, "active": false}'""",
],
# TODO: support complex types
# ["ARRAY['a','b','c']", """'{"a", "b", "c"}'"""],
# ["ARRAY[1,2,3]", """'{1, 2, 3}'"""],
]

@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_model_sql,
"my_upstream_model.sql": my_upstream_model_sql,
"schema.yml": test_my_model_yml,
}

def test_unit_test_data_type(self, project, data_types):
for sql_value, yaml_value in data_types:
# Write parametrized type value to sql files
write_file(
my_upstream_model_sql.format(sql_value=sql_value),
"models",
"my_upstream_model.sql",
)

# Write parametrized type value to unit test yaml definition
write_file(
test_my_model_yml.format(yaml_value=yaml_value),
"models",
"schema.yml",
)

results = run_dbt(["run", "--select", "my_upstream_model"])
assert len(results) == 1

try:
run_dbt(["test", "--select", "my_model"])
except Exception:
raise AssertionError(f"unit test failed when testing model with {sql_value}")


class TestPostgresUnitTestingTypes(BaseUnitTestingTypes):
MichelleArk marked this conversation as resolved.
Show resolved Hide resolved
pass
Loading