diff --git a/.changes/unreleased/Fixes-20230918-105721.yaml b/.changes/unreleased/Fixes-20230918-105721.yaml new file mode 100644 index 000000000..a0bd9eee6 --- /dev/null +++ b/.changes/unreleased/Fixes-20230918-105721.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Make python models use transient config +time: 2023-09-18T10:57:21.113134+12:00 +custom: + Author: jeremyyeo + Issue: "776" diff --git a/dbt/include/snowflake/macros/materializations/table.sql b/dbt/include/snowflake/macros/materializations/table.sql index 628474caa..ef201c705 100644 --- a/dbt/include/snowflake/macros/materializations/table.sql +++ b/dbt/include/snowflake/macros/materializations/table.sql @@ -38,7 +38,18 @@ {% endmaterialization %} -{% macro py_write_table(compiled_code, target_relation, temporary=False) %} +{% macro py_write_table(compiled_code, target_relation, temporary=False, table_type=none) %} +{#- The following logic is only for backwards-compatiblity with deprecated `temporary` parameter -#} +{% if table_type is not none %} + {#- Just use the table_type as-is -#} +{% elif temporary -%} + {#- Case 1 when the deprecated `temporary` parameter is used without the replacement `table_type` parameter -#} + {%- set table_type = "temporary" -%} +{% else %} + {#- Case 2 when the deprecated `temporary` parameter is used without the replacement `table_type` parameter -#} + {#- Snowflake treats "" as meaning "permanent" -#} + {%- set table_type = "" -%} +{%- endif %} {{ compiled_code }} def materialize(session, df, target_relation): # make sure pandas exists @@ -52,7 +63,7 @@ def materialize(session, df, target_relation): # session.write_pandas does not have overwrite function df = session.createDataFrame(df) {% set target_relation_name = resolve_model_name(target_relation) %} - df.write.mode("overwrite").save_as_table('{{ target_relation_name }}', create_temp_table={{temporary}}) + df.write.mode("overwrite").save_as_table('{{ target_relation_name }}', table_type='{{table_type}}') def main(session): dbt = dbtObj(session.table) diff --git a/dbt/include/snowflake/macros/relations/table/create.sql b/dbt/include/snowflake/macros/relations/table/create.sql index 8924af00a..c6bc8f775 100644 --- a/dbt/include/snowflake/macros/relations/table/create.sql +++ b/dbt/include/snowflake/macros/relations/table/create.sql @@ -1,6 +1,15 @@ {% macro snowflake__create_table_as(temporary, relation, compiled_code, language='sql') -%} + {%- set transient = config.get('transient', default=true) -%} + + {% if temporary -%} + {%- set table_type = "temporary" -%} + {%- elif transient -%} + {%- set table_type = "transient" -%} + {%- else -%} + {%- set table_type = "" -%} + {%- endif %} + {%- if language == 'sql' -%} - {%- set transient = config.get('transient', default=true) -%} {%- set cluster_by_keys = config.get('cluster_by', default=none) -%} {%- set enable_automatic_clustering = config.get('automatic_clustering', default=false) -%} {%- set copy_grants = config.get('copy_grants', default=false) -%} @@ -17,11 +26,7 @@ {{ sql_header if sql_header is not none }} - create or replace {% if temporary -%} - temporary - {%- elif transient -%} - transient - {%- endif %} table {{ relation }} + create or replace {{ table_type }} table {{ relation }} {%- set contract_config = config.get('contract') -%} {%- if contract_config.enforced -%} {{ get_assert_columns_equivalent(sql) }} @@ -46,7 +51,7 @@ {%- endif -%} {%- elif language == 'python' -%} - {{ py_write_table(compiled_code=compiled_code, target_relation=relation, temporary=temporary) }} + {{ py_write_table(compiled_code=compiled_code, target_relation=relation, table_type=table_type) }} {%- else -%} {% do exceptions.raise_compiler_error("snowflake__create_table_as macro didn't get supported language, it got %s" % language) %} {%- endif -%} diff --git a/tests/functional/adapter/python_model_tests/_files.py b/tests/functional/adapter/python_model_tests/_files.py new file mode 100644 index 000000000..dd69f37fa --- /dev/null +++ b/tests/functional/adapter/python_model_tests/_files.py @@ -0,0 +1,48 @@ +# __table +TRANSIENT_TRUE_TABLE = """ +import pandas + +def model(dbt, session): + dbt.config(transient=True) + return pandas.DataFrame([[1,2]] * 10, columns=['test', 'test2']) +""" + + +TRANSIENT_FALSE_TABLE = """ +import pandas + +def model(dbt, session): + dbt.config(transient=False) + return pandas.DataFrame([[1,2]] * 10, columns=['test', 'test2']) +""" + + +TRANSIENT_NONE_TABLE = """ +import pandas + +def model(dbt, session): + dbt.config(transient=None) + return pandas.DataFrame([[1,2]] * 10, columns=['test', 'test2']) +""" + + +TRANSIENT_UNSET_TABLE = """ +import pandas + +def model(dbt, session): + return pandas.DataFrame([[1,2]] * 10, columns=['test', 'test2']) +""" + + +MACRO__DESCRIBE_TABLES = """ +{% macro snowflake__test__describe_tables() %} + {%- set _sql -%} + show tables; + select "name", "kind" + from table(result_scan(last_query_id())) + {%- endset %} + {% set _table = run_query(_sql) %} + + {% do return(_table) %} +{% endmacro %} +""" diff --git a/tests/functional/adapter/python_model_tests/test_table_type.py b/tests/functional/adapter/python_model_tests/test_table_type.py new file mode 100644 index 000000000..df1f34ac4 --- /dev/null +++ b/tests/functional/adapter/python_model_tests/test_table_type.py @@ -0,0 +1,35 @@ +import pytest + +from dbt.tests.util import run_dbt + +from tests.functional.adapter.python_model_tests import _files + + +class TestTableType: + @pytest.fixture(scope="class") + def macros(self): + return {"snowflake__test__describe_tables.sql": _files.MACRO__DESCRIBE_TABLES} + + @pytest.fixture(scope="class") + def models(self): + return { + # __table + "TRANSIENT_TRUE_TABLE.py": _files.TRANSIENT_TRUE_TABLE, + "TRANSIENT_FALSE_TABLE.py": _files.TRANSIENT_FALSE_TABLE, + "TRANSIENT_NONE_TABLE.py": _files.TRANSIENT_NONE_TABLE, + "TRANSIENT_UNSET_TABLE.py": _files.TRANSIENT_UNSET_TABLE, + } + + def test_expected_table_types_are_created(self, project): + run_dbt(["run"]) + expected_table_types = { + # (name, kind) - TABLE == permanent + ("TRANSIENT_TRUE_TABLE", "TRANSIENT"), + ("TRANSIENT_FALSE_TABLE", "TABLE"), + ("TRANSIENT_NONE_TABLE", "TABLE"), + ("TRANSIENT_UNSET_TABLE", "TRANSIENT"), + } + with project.adapter.connection_named("__test"): + agate_table = project.adapter.execute_macro("snowflake__test__describe_tables") + actual_table_types = {(row.get("name"), row.get("kind")) for row in agate_table.rows} + assert actual_table_types == expected_table_types