diff --git a/tests/functional/relation_tests/base.py b/tests/functional/relation_tests/base.py index 887bc7324..d08a6945b 100644 --- a/tests/functional/relation_tests/base.py +++ b/tests/functional/relation_tests/base.py @@ -23,11 +23,20 @@ """ -MACRO__GET_RENAME_SQL = """ -{% macro test__get_rename_sql(database, schema, identifier, relation_type, new_name) -%} +MACRO__GET_CREATE_BACKUP_SQL = """ +{% macro test__get_create_backup_sql(database, schema, identifier, relation_type) -%} {%- set relation = adapter.Relation.create(database=database, schema=schema, identifier=identifier, type=relation_type) -%} - {% call statement('test__get_rename_sql') -%} - {{ get_rename_sql(relation, new_name) }} + {% call statement('test__get_create_backup_sql') -%} + {{ get_create_backup_sql(relation) }} + {%- endcall %} +{% endmacro %}""" + + +MACRO__GET_RENAME_INTERMEDIATE_SQL = """ +{% macro test__get_rename_intermediate_sql(database, schema, identifier, relation_type) -%} + {%- set relation = adapter.Relation.create(database=database, schema=schema, identifier=identifier, type=relation_type) -%} + {% call statement('test__get_rename_intermediate_sql') -%} + {{ get_rename_intermediate_sql(relation) }} {%- endcall %} {% endmacro %}""" @@ -41,12 +50,17 @@ def seeds(self): def models(self): yield { "my_table.sql": TABLE, + "my_table__dbt_tmp.sql": TABLE, "my_view.sql": VIEW, + "my_view__dbt_tmp.sql": VIEW, } @pytest.fixture(scope="class") def macros(self): - yield {"test__get_rename_sql.sql": MACRO__GET_RENAME_SQL} + yield { + "test__get_create_backup_sql.sql": MACRO__GET_CREATE_BACKUP_SQL, + "test__get_rename_intermediate_sql.sql": MACRO__GET_RENAME_INTERMEDIATE_SQL, + } @pytest.fixture(scope="class", autouse=True) def setup(self, project): diff --git a/tests/functional/relation_tests/test_table.py b/tests/functional/relation_tests/test_table.py index 20ce8e305..b4a8709ea 100644 --- a/tests/functional/relation_tests/test_table.py +++ b/tests/functional/relation_tests/test_table.py @@ -3,16 +3,23 @@ class TestTable(RelationOperation): - def test_get_rename_table_sql(self, project): + def test_get_create_backup_and_rename_intermediate_sql(self, project): args = { "database": project.database, "schema": project.test_schema, "identifier": "my_table", "relation_type": "table", - "new_name": "my_new_table", } expected_statement = ( f"alter table {project.database}.{project.test_schema}.my_table " - f"rename to {project.database}.{project.test_schema}.my_new_table" + f"rename to {project.database}.{project.test_schema}.my_table__dbt_backup" + ) + self.assert_operation(project, "test__get_create_backup_sql", args, expected_statement) + + expected_statement = ( + f"alter table {project.database}.{project.test_schema}.my_table__dbt_tmp " + f"rename to {project.database}.{project.test_schema}.my_table" + ) + self.assert_operation( + project, "test__get_rename_intermediate_sql", args, expected_statement ) - self.assert_operation(project, "test__get_rename_sql", args, expected_statement) diff --git a/tests/functional/relation_tests/test_view.py b/tests/functional/relation_tests/test_view.py index 0f338365e..721455da1 100644 --- a/tests/functional/relation_tests/test_view.py +++ b/tests/functional/relation_tests/test_view.py @@ -3,16 +3,23 @@ class TestView(RelationOperation): - def test_get_rename_view_sql(self, project): + def test_get_create_backup_and_rename_intermediate_sql(self, project): args = { "database": project.database, "schema": project.test_schema, "identifier": "my_view", "relation_type": "view", - "new_name": "my_new_view", } expected_statement = ( f"alter view {project.database}.{project.test_schema}.my_view " - f"rename to {project.database}.{project.test_schema}.my_new_view" + f"rename to {project.database}.{project.test_schema}.my_view__dbt_backup" + ) + self.assert_operation(project, "test__get_create_backup_sql", args, expected_statement) + + expected_statement = ( + f"alter view {project.database}.{project.test_schema}.my_view__dbt_tmp " + f"rename to {project.database}.{project.test_schema}.my_view" + ) + self.assert_operation( + project, "test__get_rename_intermediate_sql", args, expected_statement ) - self.assert_operation(project, "test__get_rename_sql", args, expected_statement)