Skip to content

Commit

Permalink
Respect transient config for dbt Python models (dbt-labs#802)
Browse files Browse the repository at this point in the history
* add transient configs

* Backwards-compatible handling of `temporary` and `transient` configs for dbt python models

* Fix Jinja syntax errors

* add test for table_type on python models

---------

Co-authored-by: jeremyyeo <[email protected]>
Co-authored-by: colin-rogers-dbt <[email protected]>
Co-authored-by: Mike Alfare <[email protected]>
Co-authored-by: Mike Alfare <[email protected]>
  • Loading branch information
5 people authored and philippe-boyd-maxa committed Nov 27, 2023
1 parent 2a10e81 commit a39bdf3
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 9 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230918-105721.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Make python models use transient config
time: 2023-09-18T10:57:21.113134+12:00
custom:
Author: jeremyyeo
Issue: "776"
15 changes: 13 additions & 2 deletions dbt/include/snowflake/macros/materializations/table.sql
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,18 @@

{% endmaterialization %}

{% macro py_write_table(compiled_code, target_relation, temporary=False) %}
{% macro py_write_table(compiled_code, target_relation, temporary=False, table_type=none) %}
{#- The following logic is only for backwards-compatiblity with deprecated `temporary` parameter -#}
{% if table_type is not none %}
{#- Just use the table_type as-is -#}
{% elif temporary -%}
{#- Case 1 when the deprecated `temporary` parameter is used without the replacement `table_type` parameter -#}
{%- set table_type = "temporary" -%}
{% else %}
{#- Case 2 when the deprecated `temporary` parameter is used without the replacement `table_type` parameter -#}
{#- Snowflake treats "" as meaning "permanent" -#}
{%- set table_type = "" -%}
{%- endif %}
{{ compiled_code }}
def materialize(session, df, target_relation):
# make sure pandas exists
Expand All @@ -52,7 +63,7 @@ def materialize(session, df, target_relation):
# session.write_pandas does not have overwrite function
df = session.createDataFrame(df)
{% set target_relation_name = resolve_model_name(target_relation) %}
df.write.mode("overwrite").save_as_table('{{ target_relation_name }}', create_temp_table={{temporary}})
df.write.mode("overwrite").save_as_table('{{ target_relation_name }}', table_type='{{table_type}}')

def main(session):
dbt = dbtObj(session.table)
Expand Down
19 changes: 12 additions & 7 deletions dbt/include/snowflake/macros/relations/table/create.sql
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
{% macro snowflake__create_table_as(temporary, relation, compiled_code, language='sql') -%}
{%- set transient = config.get('transient', default=true) -%}

{% if temporary -%}
{%- set table_type = "temporary" -%}
{%- elif transient -%}
{%- set table_type = "transient" -%}
{%- else -%}
{%- set table_type = "" -%}
{%- endif %}

{%- if language == 'sql' -%}
{%- set transient = config.get('transient', default=true) -%}
{%- set cluster_by_keys = config.get('cluster_by', default=none) -%}
{%- set enable_automatic_clustering = config.get('automatic_clustering', default=false) -%}
{%- set copy_grants = config.get('copy_grants', default=false) -%}
Expand All @@ -17,11 +26,7 @@

{{ sql_header if sql_header is not none }}

create or replace {% if temporary -%}
temporary
{%- elif transient -%}
transient
{%- endif %} table {{ relation }}
create or replace {{ table_type }} table {{ relation }}
{%- set contract_config = config.get('contract') -%}
{%- if contract_config.enforced -%}
{{ get_assert_columns_equivalent(sql) }}
Expand All @@ -46,7 +51,7 @@
{%- endif -%}

{%- elif language == 'python' -%}
{{ py_write_table(compiled_code=compiled_code, target_relation=relation, temporary=temporary) }}
{{ py_write_table(compiled_code=compiled_code, target_relation=relation, table_type=table_type) }}
{%- else -%}
{% do exceptions.raise_compiler_error("snowflake__create_table_as macro didn't get supported language, it got %s" % language) %}
{%- endif -%}
Expand Down
48 changes: 48 additions & 0 deletions tests/functional/adapter/python_model_tests/_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# <temporary>_<transient>_table
TRANSIENT_TRUE_TABLE = """
import pandas
def model(dbt, session):
dbt.config(transient=True)
return pandas.DataFrame([[1,2]] * 10, columns=['test', 'test2'])
"""


TRANSIENT_FALSE_TABLE = """
import pandas
def model(dbt, session):
dbt.config(transient=False)
return pandas.DataFrame([[1,2]] * 10, columns=['test', 'test2'])
"""


TRANSIENT_NONE_TABLE = """
import pandas
def model(dbt, session):
dbt.config(transient=None)
return pandas.DataFrame([[1,2]] * 10, columns=['test', 'test2'])
"""


TRANSIENT_UNSET_TABLE = """
import pandas
def model(dbt, session):
return pandas.DataFrame([[1,2]] * 10, columns=['test', 'test2'])
"""


MACRO__DESCRIBE_TABLES = """
{% macro snowflake__test__describe_tables() %}
{%- set _sql -%}
show tables;
select "name", "kind"
from table(result_scan(last_query_id()))
{%- endset %}
{% set _table = run_query(_sql) %}
{% do return(_table) %}
{% endmacro %}
"""
35 changes: 35 additions & 0 deletions tests/functional/adapter/python_model_tests/test_table_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest

from dbt.tests.util import run_dbt

from tests.functional.adapter.python_model_tests import _files


class TestTableType:
@pytest.fixture(scope="class")
def macros(self):
return {"snowflake__test__describe_tables.sql": _files.MACRO__DESCRIBE_TABLES}

@pytest.fixture(scope="class")
def models(self):
return {
# <temporary>_<transient>_table
"TRANSIENT_TRUE_TABLE.py": _files.TRANSIENT_TRUE_TABLE,
"TRANSIENT_FALSE_TABLE.py": _files.TRANSIENT_FALSE_TABLE,
"TRANSIENT_NONE_TABLE.py": _files.TRANSIENT_NONE_TABLE,
"TRANSIENT_UNSET_TABLE.py": _files.TRANSIENT_UNSET_TABLE,
}

def test_expected_table_types_are_created(self, project):
run_dbt(["run"])
expected_table_types = {
# (name, kind) - TABLE == permanent
("TRANSIENT_TRUE_TABLE", "TRANSIENT"),
("TRANSIENT_FALSE_TABLE", "TABLE"),
("TRANSIENT_NONE_TABLE", "TABLE"),
("TRANSIENT_UNSET_TABLE", "TRANSIENT"),
}
with project.adapter.connection_named("__test"):
agate_table = project.adapter.execute_macro("snowflake__test__describe_tables")
actual_table_types = {(row.get("name"), row.get("kind")) for row in agate_table.rows}
assert actual_table_types == expected_table_types

0 comments on commit a39bdf3

Please sign in to comment.