From 4e3b10c0463328ffa0cc3e8b69af207a2b7a4bbb Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Tue, 20 Feb 2024 17:27:05 -0500 Subject: [PATCH] implement safe_cast + add tests for array, map, named_struct for unit testing --- dbt/include/spark/macros/utils/safe_cast.sql | 11 +++++++++++ .../adapter/unit_testing/test_unit_testing.py | 9 ++++++--- 2 files changed, 17 insertions(+), 3 deletions(-) create mode 100644 dbt/include/spark/macros/utils/safe_cast.sql diff --git a/dbt/include/spark/macros/utils/safe_cast.sql b/dbt/include/spark/macros/utils/safe_cast.sql new file mode 100644 index 000000000..c8922c5f8 --- /dev/null +++ b/dbt/include/spark/macros/utils/safe_cast.sql @@ -0,0 +1,11 @@ +{% macro spark__safe_cast(field, type) %} +{%- if cast_from_string_unsupported_for(type) and field is string -%} + cast({{field.strip('"').strip("'")}} as {{type}}) +{%- else -%} + safe_cast({{field}} as {{type}}) +{%- endif -%} +{% endmacro %} + +{% macro cast_from_string_unsupported_for(type) %} + {{ return(type.lower().startswith('struct') or type.lower().startswith('array') or type.lower().startswith('map')) }} +{% endmacro %} diff --git a/tests/functional/adapter/unit_testing/test_unit_testing.py b/tests/functional/adapter/unit_testing/test_unit_testing.py index 45dbc3356..b70c581d1 100644 --- a/tests/functional/adapter/unit_testing/test_unit_testing.py +++ b/tests/functional/adapter/unit_testing/test_unit_testing.py @@ -17,9 +17,12 @@ def data_types(self): ["true", "true"], ["date '2011-11-11'", "2011-11-11"], ["timestamp '2013-11-03 00:00:00-0'", "2013-11-03 00:00:00-0"], - # ["map(struct('Hello', 'World'), 'Greeting')", '''"map(struct('Hello', 'World'), 'Greeting')"'''], - # ['named_struct("a", 1, "b", 2, "c", 3)', """'named_struct("a", 1, "b", 2, "c", 3)'"""], - # ["array(1, 2, 3)", "'array(1, 2, 3)'"], + ["array(1, 2, 3)", "'array(1, 2, 3)'"], + [ + "map('10', 't', '15', 'f', '20', NULL)", + """'map("10", "t", "15", "f", "20", NULL)'""", + ], + ['named_struct("a", 1, "b", 2, "c", 3)', """'named_struct("a", 1, "b", 2, "c", 3)'"""], ]