diff --git a/.changes/unreleased/Features-20240220-195925.yaml b/.changes/unreleased/Features-20240220-195925.yaml
new file mode 100644
index 000000000..c5d86ab7c
--- /dev/null
+++ b/.changes/unreleased/Features-20240220-195925.yaml
@@ -0,0 +1,6 @@
+kind: Features
+body: Implement spark__safe_cast and add functional tests for unit testing
+time: 2024-02-20T19:59:25.907821-05:00
+custom:
+  Author: michelleark
+  Issue: "987"
diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql
index bf9f63cf9..a6404a2de 100644
--- a/dbt/include/spark/macros/adapters.sql
+++ b/dbt/include/spark/macros/adapters.sql
@@ -387,6 +387,7 @@
         "identifier": tmp_identifier
     }) -%}
 
+    {%- set tmp_relation = tmp_relation.include(database=false, schema=false) -%}
     {% do return(tmp_relation) %}
 {% endmacro %}
 
diff --git a/dbt/include/spark/macros/utils/safe_cast.sql b/dbt/include/spark/macros/utils/safe_cast.sql
new file mode 100644
index 000000000..3ce5820a8
--- /dev/null
+++ b/dbt/include/spark/macros/utils/safe_cast.sql
@@ -0,0 +1,8 @@
+{% macro spark__safe_cast(field, type) %}
+{%- set field_clean = field.strip('"').strip("'") if (cast_from_string_unsupported_for(type) and field is string) else field -%}
+cast({{field_clean}} as {{type}})
+{% endmacro %}
+
+{% macro cast_from_string_unsupported_for(type) %}
+  {{ return(type.lower().startswith('struct') or type.lower().startswith('array') or type.lower().startswith('map')) }}
+{% endmacro %}
diff --git a/tests/functional/adapter/unit_testing/test_unit_testing.py b/tests/functional/adapter/unit_testing/test_unit_testing.py
new file mode 100644
index 000000000..b70c581d1
--- /dev/null
+++ b/tests/functional/adapter/unit_testing/test_unit_testing.py
@@ -0,0 +1,34 @@
+import pytest
+
+from dbt.tests.adapter.unit_testing.test_types import BaseUnitTestingTypes
+from dbt.tests.adapter.unit_testing.test_case_insensitivity import BaseUnitTestCaseInsensivity
+from dbt.tests.adapter.unit_testing.test_invalid_input import BaseUnitTestInvalidInput
+
+
+class TestSparkUnitTestingTypes(BaseUnitTestingTypes):
+    @pytest.fixture
+    def data_types(self):
+        # sql_value, yaml_value
+        return [
+            ["1", "1"],
+            ["2.0", "2.0"],
+            ["'12345'", "12345"],
+            ["'string'", "string"],
+            ["true", "true"],
+            ["date '2011-11-11'", "2011-11-11"],
+            ["timestamp '2013-11-03 00:00:00-0'", "2013-11-03 00:00:00-0"],
+            ["array(1, 2, 3)", "'array(1, 2, 3)'"],
+            [
+                "map('10', 't', '15', 'f', '20', NULL)",
+                """'map("10", "t", "15", "f", "20", NULL)'""",
+            ],
+            ['named_struct("a", 1, "b", 2, "c", 3)', """'named_struct("a", 1, "b", 2, "c", 3)'"""],
+        ]
+
+
+class TestSparkUnitTestCaseInsensitivity(BaseUnitTestCaseInsensivity):
+    pass
+
+
+class TestSparkUnitTestInvalidInput(BaseUnitTestInvalidInput):
+    pass