diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1141ccc97..b748e03ec 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,8 +1,5 @@ # For more on configuring pre-commit hooks (see https://pre-commit.com/) -# TODO: remove global exclusion of tests when testing overhaul is complete -exclude: '^tests/.*' - # Force all unspecified python hooks to run python 3.8 default_language_version: python: python3 diff --git a/pytest.ini b/pytest.ini index b04a6ccf3..b3d74bc14 100644 --- a/pytest.ini +++ b/pytest.ini @@ -6,5 +6,4 @@ env_files = test.env testpaths = tests/unit - tests/integration tests/functional diff --git a/tests/conftest.py b/tests/conftest.py index 18fcbb714..96f0d43e4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,12 +11,12 @@ @pytest.fixture(scope="class") def dbt_profile_target(): return { - 'type': 'redshift', - 'threads': 1, - 'retries': 6, - 'host': os.getenv('REDSHIFT_TEST_HOST'), - 'port': int(os.getenv('REDSHIFT_TEST_PORT')), - 'user': os.getenv('REDSHIFT_TEST_USER'), - 'pass': os.getenv('REDSHIFT_TEST_PASS'), - 'dbname': os.getenv('REDSHIFT_TEST_DBNAME'), + "type": "redshift", + "threads": 1, + "retries": 6, + "host": os.getenv("REDSHIFT_TEST_HOST"), + "port": int(os.getenv("REDSHIFT_TEST_PORT")), + "user": os.getenv("REDSHIFT_TEST_USER"), + "pass": os.getenv("REDSHIFT_TEST_PASS"), + "dbname": os.getenv("REDSHIFT_TEST_DBNAME"), } diff --git a/tests/functional/adapter/common.py b/tests/functional/adapter/common.py index 914e3fcf8..ce7c0903b 100644 --- a/tests/functional/adapter/common.py +++ b/tests/functional/adapter/common.py @@ -4,7 +4,9 @@ from dbt.tests.fixtures.project import TestProjInfo -def get_records(project: TestProjInfo, table: str, select: str = None, where: str = None) -> List[tuple]: +def get_records( + project: TestProjInfo, table: str, select: str = None, where: str = None +) -> List[tuple]: """ Gets records from a single table in a dbt project @@ -39,7 +41,9 @@ def update_records(project: TestProjInfo, table: str, updates: Dict[str, str], w where: the where clause to apply, if any; defaults to all records """ table_name = relation_from_name(project.adapter, table) - set_clause = ', '.join([' = '.join([field, expression]) for field, expression in updates.items()]) + set_clause = ", ".join( + [" = ".join([field, expression]) for field, expression in updates.items()] + ) where_clause = where or "1 = 1" sql = f""" update {table_name} @@ -49,7 +53,9 @@ def update_records(project: TestProjInfo, table: str, updates: Dict[str, str], w project.run_sql(sql) -def insert_records(project: TestProjInfo, to_table: str, from_table: str, select: str, where: str = None): +def insert_records( + project: TestProjInfo, to_table: str, from_table: str, select: str, where: str = None +): """ Inserts records from one table into another table in a dbt project @@ -91,7 +97,9 @@ def delete_records(project: TestProjInfo, table: str, where: str = None): project.run_sql(sql) -def clone_table(project: TestProjInfo, to_table: str, from_table: str, select: str, where: str = None): +def clone_table( + project: TestProjInfo, to_table: str, from_table: str, select: str, where: str = None +): """ Creates a new table based on another table in a dbt project diff --git a/tests/functional/adapter/conftest.py b/tests/functional/adapter/conftest.py index e4aa4fe31..c5c980154 100644 --- a/tests/functional/adapter/conftest.py +++ b/tests/functional/adapter/conftest.py @@ -21,5 +21,5 @@ def test_setting_reflects_config_option(self, model_ddl: str, backup_expected: b In this example, the fixture returns the contents of the backup_is_false DDL file as a string. This string is then referenced in the test as model_ddl. """ - with open(f"target/run/test/models/{request.param}.sql", 'r') as ddl_file: - yield '\n'.join(ddl_file.readlines()) + with open(f"target/run/test/models/{request.param}.sql", "r") as ddl_file: + yield "\n".join(ddl_file.readlines()) diff --git a/tests/functional/adapter/incremental/test_incremental_on_schema_change.py b/tests/functional/adapter/incremental/test_incremental_on_schema_change.py index 192097bc5..7b73d212b 100644 --- a/tests/functional/adapter/incremental/test_incremental_on_schema_change.py +++ b/tests/functional/adapter/incremental/test_incremental_on_schema_change.py @@ -1,4 +1,7 @@ -from dbt.tests.adapter.incremental.test_incremental_on_schema_change import BaseIncrementalOnSchemaChange +from dbt.tests.adapter.incremental.test_incremental_on_schema_change import ( + BaseIncrementalOnSchemaChange, +) + class TestIncrementalOnSchemaChange(BaseIncrementalOnSchemaChange): pass diff --git a/tests/functional/adapter/incremental/test_incremental_unique_id.py b/tests/functional/adapter/incremental/test_incremental_unique_id.py index 14e9b7ea8..5fcdfbe16 100644 --- a/tests/functional/adapter/incremental/test_incremental_unique_id.py +++ b/tests/functional/adapter/incremental/test_incremental_unique_id.py @@ -2,4 +2,4 @@ class TestUniqueKeyRedshift(BaseIncrementalUniqueKey): - pass \ No newline at end of file + pass diff --git a/tests/functional/adapter/snapshot_tests/test_snapshot.py b/tests/functional/adapter/snapshot_tests/test_snapshot.py index 0f6153f47..9a4ef7694 100644 --- a/tests/functional/adapter/snapshot_tests/test_snapshot.py +++ b/tests/functional/adapter/snapshot_tests/test_snapshot.py @@ -16,7 +16,6 @@ class SnapshotBase: - @pytest.fixture(scope="class") def seeds(self): """ @@ -80,9 +79,9 @@ def delete_snapshot_records(self): common.delete_records(self.project, "snapshot") def _assert_results( - self, - ids_with_current_snapshot_records: Iterable, - ids_with_closed_out_snapshot_records: Iterable + self, + ids_with_current_snapshot_records: Iterable, + ids_with_closed_out_snapshot_records: Iterable, ): """ All test cases are checked by considering whether a source record's id has a value in `dbt_valid_to` @@ -106,13 +105,12 @@ def _assert_results( records = set(self.get_snapshot_records("id, dbt_valid_to is null as is_current")) expected_records = set().union( {(i, True) for i in ids_with_current_snapshot_records}, - {(i, False) for i in ids_with_closed_out_snapshot_records} + {(i, False) for i in ids_with_closed_out_snapshot_records}, ) assert records == expected_records class TestSnapshot(SnapshotBase): - @pytest.fixture(scope="class") def snapshots(self): return {"snapshot.sql": snapshots.SNAPSHOT_TIMESTAMP_SQL} @@ -121,11 +119,13 @@ def test_updates_are_captured_by_snapshot(self, project): """ Update the last 5 records. Show that all ids are current, but the last 5 reflect updates. """ - self.update_fact_records({"updated_at": "updated_at + interval '1 day'"}, "id between 16 and 20") + self.update_fact_records( + {"updated_at": "updated_at + interval '1 day'"}, "id between 16 and 20" + ) run_dbt(["snapshot"]) self._assert_results( ids_with_current_snapshot_records=range(1, 21), - ids_with_closed_out_snapshot_records=range(16, 21) + ids_with_closed_out_snapshot_records=range(16, 21), ) def test_inserts_are_captured_by_snapshot(self, project): @@ -135,8 +135,7 @@ def test_inserts_are_captured_by_snapshot(self, project): self.insert_fact_records("id between 21 and 30") run_dbt(["snapshot"]) self._assert_results( - ids_with_current_snapshot_records=range(1, 31), - ids_with_closed_out_snapshot_records=[] + ids_with_current_snapshot_records=range(1, 31), ids_with_closed_out_snapshot_records=[] ) def test_deletes_are_captured_by_snapshot(self, project): @@ -147,7 +146,7 @@ def test_deletes_are_captured_by_snapshot(self, project): run_dbt(["snapshot"]) self._assert_results( ids_with_current_snapshot_records=range(1, 16), - ids_with_closed_out_snapshot_records=range(16, 21) + ids_with_closed_out_snapshot_records=range(16, 21), ) def test_revives_are_captured_by_snapshot(self, project): @@ -161,7 +160,7 @@ def test_revives_are_captured_by_snapshot(self, project): run_dbt(["snapshot"]) self._assert_results( ids_with_current_snapshot_records=range(1, 19), - ids_with_closed_out_snapshot_records=range(16, 21) + ids_with_closed_out_snapshot_records=range(16, 21), ) def test_new_column_captured_by_snapshot(self, project): @@ -176,17 +175,16 @@ def test_new_column_captured_by_snapshot(self, project): "full_name": "first_name || ' ' || last_name", "updated_at": "updated_at + interval '1 day'", }, - "id between 11 and 20" + "id between 11 and 20", ) run_dbt(["snapshot"]) self._assert_results( ids_with_current_snapshot_records=range(1, 21), - ids_with_closed_out_snapshot_records=range(11, 21) + ids_with_closed_out_snapshot_records=range(11, 21), ) class TestSnapshotCheck(SnapshotBase): - @pytest.fixture(scope="class") def snapshots(self): return {"snapshot.sql": snapshots.SNAPSHOT_CHECK_SQL} @@ -197,10 +195,12 @@ def test_column_selection_is_reflected_in_snapshot(self, project): Update the middle 10 records on a tracked column. (hence records 6-10 are updated on both) Show that all ids are current, and only the tracked column updates are reflected in `snapshot`. """ - self.update_fact_records({"last_name": "left(last_name, 3)"}, "id between 1 and 10") # not tracked - self.update_fact_records({"email": "left(email, 3)"}, "id between 6 and 15") # tracked + self.update_fact_records( + {"last_name": "left(last_name, 3)"}, "id between 1 and 10" + ) # not tracked + self.update_fact_records({"email": "left(email, 3)"}, "id between 6 and 15") # tracked run_dbt(["snapshot"]) self._assert_results( ids_with_current_snapshot_records=range(1, 21), - ids_with_closed_out_snapshot_records=range(6, 16) + ids_with_closed_out_snapshot_records=range(6, 16), ) diff --git a/tests/functional/adapter/test_backup_table.py b/tests/functional/adapter/test_backup_table.py index ee58615b3..4c24250f1 100644 --- a/tests/functional/adapter/test_backup_table.py +++ b/tests/functional/adapter/test_backup_table.py @@ -59,14 +59,12 @@ class BackupTableBase: - @pytest.fixture(scope="class", autouse=True) def _run_dbt(self, project): run_dbt(["run"]) class TestBackupTableOption(BackupTableBase): - @pytest.fixture(scope="class") def models(self): return { @@ -84,7 +82,7 @@ def models(self): ("backup_is_undefined", True), ("backup_is_true_view", True), ], - indirect=["model_ddl"] + indirect=["model_ddl"], ) def test_setting_reflects_config_option(self, model_ddl: str, backup_expected: bool): """ @@ -102,7 +100,6 @@ def test_setting_reflects_config_option(self, model_ddl: str, backup_expected: b class TestBackupTableSyntax(BackupTableBase): - @pytest.fixture(scope="class") def models(self): return { @@ -116,7 +113,7 @@ def models(self): ("syntax_with_distkey", "diststyle key distkey"), ("syntax_with_sortkey", "compound sortkey"), ], - indirect=["model_ddl"] + indirect=["model_ddl"], ) def test_backup_predicate_precedes_secondary_predicates(self, model_ddl, search_phrase): """ @@ -133,7 +130,6 @@ def test_backup_predicate_precedes_secondary_predicates(self, model_ddl, search_ class TestBackupTableProjectDefault(BackupTableBase): - @pytest.fixture(scope="class") def project_config_update(self): return {"models": {"backup": False}} @@ -147,11 +143,8 @@ def models(self): @pytest.mark.parametrize( "model_ddl,backup_expected", - [ - ("backup_is_true", True), - ("backup_is_undefined", False) - ], - indirect=["model_ddl"] + [("backup_is_true", True), ("backup_is_undefined", False)], + indirect=["model_ddl"], ) def test_setting_defaults_to_project_option(self, model_ddl: str, backup_expected: bool): """ diff --git a/tests/functional/adapter/test_basic.py b/tests/functional/adapter/test_basic.py index 06cf9948f..d2289efa3 100644 --- a/tests/functional/adapter/test_basic.py +++ b/tests/functional/adapter/test_basic.py @@ -12,10 +12,17 @@ from dbt.tests.adapter.basic.test_snapshot_timestamp import BaseSnapshotTimestamp from dbt.tests.adapter.basic.test_adapter_methods import BaseAdapterMethod from dbt.tests.adapter.basic.test_docs_generate import BaseDocsGenerate, BaseDocsGenReferences -from dbt.tests.adapter.basic.expected_catalog import base_expected_catalog, no_stats, expected_references_catalog +from dbt.tests.adapter.basic.expected_catalog import ( + base_expected_catalog, + no_stats, + expected_references_catalog, +) from dbt.tests.adapter.basic.files import seeds_base_csv, seeds_added_csv, seeds_newcolumns_csv -from tests.functional.adapter.expected_stats import redshift_stats, redshift_ephemeral_summary_stats +from tests.functional.adapter.expected_stats import ( + redshift_stats, + redshift_ephemeral_summary_stats, +) # set the datatype of the name column in the 'added' seed so that it can hold the '_update' that's added @@ -86,19 +93,19 @@ class TestBaseAdapterMethod(BaseAdapterMethod): class TestDocsGenerateRedshift(BaseDocsGenerate): - @pytest.fixture(scope="class") + @pytest.fixture(scope="class") def expected_catalog(self, project, profile_user): return base_expected_catalog( - project, - role=profile_user, - id_type="integer", + project, + role=profile_user, + id_type="integer", text_type=AnyStringWith("character varying"), time_type="timestamp without time zone", - view_type="VIEW", - table_type="BASE TABLE", + view_type="VIEW", + table_type="BASE TABLE", model_stats=no_stats(), seed_stats=redshift_stats(), - ) + ) # TODO: update this or delete it diff --git a/tests/functional/adapter/test_changing_relation_type.py b/tests/functional/adapter/test_changing_relation_type.py index 1f0ba15ad..81ba99918 100644 --- a/tests/functional/adapter/test_changing_relation_type.py +++ b/tests/functional/adapter/test_changing_relation_type.py @@ -1,4 +1,5 @@ from dbt.tests.adapter.relations.test_changing_relation_type import BaseChangeRelationTypeValidator + class TestRedshiftChangeRelationTypes(BaseChangeRelationTypeValidator): - pass \ No newline at end of file + pass diff --git a/tests/functional/adapter/test_column_types.py b/tests/functional/adapter/test_column_types.py index 81d5ca0fa..e24167456 100644 --- a/tests/functional/adapter/test_column_types.py +++ b/tests/functional/adapter/test_column_types.py @@ -1,7 +1,7 @@ import pytest from dbt.tests.adapter.column_types.test_column_types import BaseColumnTypes -_MODEL_SQL = """ +_MODEL_SQL = """ select 1::smallint as smallint_col, 2::int as int_col, @@ -46,14 +46,11 @@ text_col: ['string', 'not number'] """ -class TestRedshiftColumnTypes(BaseColumnTypes): +class TestRedshiftColumnTypes(BaseColumnTypes): @pytest.fixture(scope="class") def models(self): - return { - "model.sql": _MODEL_SQL, - "schema.yml": _SCHEMA_YML - } + return {"model.sql": _MODEL_SQL, "schema.yml": _SCHEMA_YML} def test_run_and_test(self, project): - self.run_and_test() \ No newline at end of file + self.run_and_test() diff --git a/tests/functional/adapter/test_constraints.py b/tests/functional/adapter/test_constraints.py index 94283fc3e..9918b5037 100644 --- a/tests/functional/adapter/test_constraints.py +++ b/tests/functional/adapter/test_constraints.py @@ -3,7 +3,7 @@ from dbt.tests.adapter.constraints.test_constraints import ( BaseTableConstraintsColumnsEqual, BaseViewConstraintsColumnsEqual, - BaseConstraintsRuntimeEnforcement + BaseConstraintsRuntimeEnforcement, ) _expected_sql_redshift = """ @@ -38,26 +38,29 @@ def data_types(self, schema_int_type, int_type, string_type): ["true", "bool", "BOOL"], ["'2013-11-03 00:00:00-07'::timestamptz", "timestamptz", "TIMESTAMPTZ"], ["'2013-11-03 00:00:00-07'::timestamp", "timestamp", "TIMESTAMP"], - ["'1'::numeric", "numeric", "NUMERIC"] + ["'1'::numeric", "numeric", "NUMERIC"], ] -class TestRedshiftTableConstraintsColumnsEqual(RedshiftColumnEqualSetup, BaseTableConstraintsColumnsEqual): +class TestRedshiftTableConstraintsColumnsEqual( + RedshiftColumnEqualSetup, BaseTableConstraintsColumnsEqual +): pass -class TestRedshiftViewConstraintsColumnsEqual(RedshiftColumnEqualSetup, BaseViewConstraintsColumnsEqual): +class TestRedshiftViewConstraintsColumnsEqual( + RedshiftColumnEqualSetup, BaseViewConstraintsColumnsEqual +): pass + class TestRedshiftConstraintsRuntimeEnforcement(BaseConstraintsRuntimeEnforcement): @pytest.fixture(scope="class") def expected_sql(self, project): relation = relation_from_name(project.adapter, "my_model") - tmp_relation = relation.incorporate( - path={"identifier": relation.identifier + "__dbt_tmp"} - ) + tmp_relation = relation.incorporate(path={"identifier": relation.identifier + "__dbt_tmp"}) return _expected_sql_redshift.format(tmp_relation) - + @pytest.fixture(scope="class") def expected_error_messages(self): - return ['Cannot insert a NULL value into column id'] + return ["Cannot insert a NULL value into column id"] diff --git a/tests/functional/adapter/test_grants.py b/tests/functional/adapter/test_grants.py index bbad59f96..b627e450a 100644 --- a/tests/functional/adapter/test_grants.py +++ b/tests/functional/adapter/test_grants.py @@ -1,7 +1,5 @@ -import pytest from dbt.tests.adapter.grants.test_model_grants import BaseModelGrants from dbt.tests.adapter.grants.test_incremental_grants import BaseIncrementalGrants -from dbt.tests.adapter.grants.test_invalid_grants import BaseInvalidGrants from dbt.tests.adapter.grants.test_seed_grants import BaseSeedGrants from dbt.tests.adapter.grants.test_snapshot_grants import BaseSnapshotGrants diff --git a/tests/functional/adapter/test_late_binding_view.py b/tests/functional/adapter/test_late_binding_view.py index 7c7bfa69d..013bf06be 100644 --- a/tests/functional/adapter/test_late_binding_view.py +++ b/tests/functional/adapter/test_late_binding_view.py @@ -18,7 +18,6 @@ class TestLateBindingView: - @pytest.fixture(scope="class") def models(self): return { @@ -27,20 +26,18 @@ def models(self): @pytest.fixture(scope="class") def seeds(self): - return { - "seed.csv": _SEED_CSV - } + return {"seed.csv": _SEED_CSV} @pytest.fixture(scope="class") def project_config_update(self): return { - 'seeds': { - 'quote_columns': False, + "seeds": { + "quote_columns": False, } } def test_late_binding_view_query(self, project): - seed_run_result = run_dbt(['seed']) + seed_run_result = run_dbt(["seed"]) assert len(seed_run_result) == 1 run_result = run_dbt() assert len(run_result) == 1 diff --git a/tests/functional/adapter/test_macros.py b/tests/functional/adapter/test_macros.py index 0994cae28..0596ab549 100644 --- a/tests/functional/adapter/test_macros.py +++ b/tests/functional/adapter/test_macros.py @@ -22,33 +22,29 @@ {% endmacro %} {% macro dispatch_to_parent() %} - {% set macro = adapter.dispatch('dispatch_to_parent') %} - {{ macro() }} + {% set macro = adapter.dispatch('dispatch_to_parent') %} + {{ macro() }} {% endmacro %} {% macro default__dispatch_to_parent() %} - {% set msg = 'No default implementation of dispatch_to_parent' %} + {% set msg = 'No default implementation of dispatch_to_parent' %} {{ exceptions.raise_compiler_error(msg) }} {% endmacro %} {% macro postgres__dispatch_to_parent() %} - {{ return('') }} + {{ return('') }} {% endmacro %} """ -class TestRedshift: +class TestRedshift: @pytest.fixture(scope="class") def macros(self): - return { - "macro.sql": _MACRO_SQL - } + return {"macro.sql": _MACRO_SQL} @pytest.fixture(scope="class") def models(self): - return { - "model.sql": _MODEL_SQL - } + return {"model.sql": _MODEL_SQL} def test_inherited_macro(self, project): - run_dbt() \ No newline at end of file + run_dbt() diff --git a/tests/functional/adapter/test_persist_docs.py b/tests/functional/adapter/test_persist_docs.py index 4d18f8ec6..61b8bd5a6 100644 --- a/tests/functional/adapter/test_persist_docs.py +++ b/tests/functional/adapter/test_persist_docs.py @@ -27,31 +27,31 @@ class TestPersistDocsLateBinding(BasePersistDocsBase): @pytest.fixture(scope="class") def project_config_update(self): return { - 'models': { - 'test': { - '+persist_docs': { + "models": { + "test": { + "+persist_docs": { "relation": True, "columns": True, }, - 'view_model': { - 'bind': False, - } + "view_model": { + "bind": False, + }, } } } def test_comment_on_late_binding_view(self, project): run_dbt() - run_dbt(['docs', 'generate']) - with open('target/catalog.json') as fp: + run_dbt(["docs", "generate"]) + with open("target/catalog.json") as fp: catalog_data = json.load(fp) - assert 'nodes' in catalog_data - assert len(catalog_data['nodes']) == 4 - table_node = catalog_data['nodes']['model.test.table_model'] + assert "nodes" in catalog_data + assert len(catalog_data["nodes"]) == 4 + table_node = catalog_data["nodes"]["model.test.table_model"] view_node = self._assert_has_table_comments(table_node) - view_node = catalog_data['nodes']['model.test.view_model'] + view_node = catalog_data["nodes"]["model.test.view_model"] self._assert_has_view_comments(view_node, False, False) - no_docs_node = catalog_data['nodes']['model.test.no_docs_model'] + no_docs_node = catalog_data["nodes"]["model.test.no_docs_model"] self._assert_has_view_comments(no_docs_node, False, False) diff --git a/tests/functional/adapter/test_query_comment.py b/tests/functional/adapter/test_query_comment.py index 281a90867..db6a440d7 100644 --- a/tests/functional/adapter/test_query_comment.py +++ b/tests/functional/adapter/test_query_comment.py @@ -1,4 +1,3 @@ -import pytest from dbt.tests.adapter.query_comment.test_query_comment import ( BaseQueryComments, BaseMacroQueryComments, @@ -12,17 +11,22 @@ class TestQueryCommentsRedshift(BaseQueryComments): pass + class TestMacroQueryCommentsRedshift(BaseMacroQueryComments): pass + class TestMacroArgsQueryCommentsRedshift(BaseMacroArgsQueryComments): pass + class TestMacroInvalidQueryCommentsRedshift(BaseMacroInvalidQueryComments): pass + class TestNullQueryCommentsRedshift(BaseNullQueryComments): pass + class TestEmptyQueryCommentsRedshift(BaseEmptyQueryComments): - pass \ No newline at end of file + pass diff --git a/tests/functional/adapter/test_relation_name.py b/tests/functional/adapter/test_relation_name.py index 733ef7f49..f17bbda63 100644 --- a/tests/functional/adapter/test_relation_name.py +++ b/tests/functional/adapter/test_relation_name.py @@ -68,9 +68,7 @@ def setUp(self, project): @pytest.fixture(scope="class") def seeds(self): - return { - "seed.csv": seeds__seed - } + return {"seed.csv": seeds__seed} @pytest.fixture(scope="class") def project_config_update(self): @@ -84,12 +82,8 @@ def project_config_update(self): class TestAdapterDDL(TestAdapterDDLBase): @pytest.fixture(scope="class") def models(self): - relname_51_chars_long = ( - "incremental_table_whose_name_is_51_characters_abcde.sql" - ) - relname_52_chars_long = ( - "relation_whose_name_is_52_chars_long_abcdefghijklmno.sql" - ) + relname_51_chars_long = "incremental_table_whose_name_is_51_characters_abcde.sql" + relname_52_chars_long = "relation_whose_name_is_52_chars_long_abcdefghijklmno.sql" relname_63_chars_long = ( "relation_whose_name_is_63_chars_long_abcdefghijklmnopqrstuvwxyz.sql" ) @@ -110,7 +104,7 @@ def models(self): relname_63_chars_long: models__relationname_63_chars_long, relname_63_chars_long_b: models__relationname_63_chars_long, relname_64_chars_long: models__relationname_64_chars_long, - relname_127_chars_long: models__relationname_127_chars_long + relname_127_chars_long: models__relationname_127_chars_long, } def test_long_name_succeeds(self, project): @@ -127,9 +121,7 @@ def models(self): "relation_whose_name_is_127_characters89012345678901234567890123456" "78901234567890123456789012345678901234567890123456789012345678.sql" ) - return { - relname_128_chars_long: models__relationname_127_chars_long - } + return {relname_128_chars_long: models__relationname_127_chars_long} def test_too_long_of_name_fails(self, project): results = run_dbt(["run"], expect_pass=False) diff --git a/tests/functional/adapter/test_simple_seed.py b/tests/functional/adapter/test_simple_seed.py index 5c57f1895..1e8cc1cd5 100644 --- a/tests/functional/adapter/test_simple_seed.py +++ b/tests/functional/adapter/test_simple_seed.py @@ -60,9 +60,7 @@ def schema(self): @pytest.fixture(scope="class") def models(self): - return { - "models-rs.yml": _SCHEMA_YML - } + return {"models-rs.yml": _SCHEMA_YML} @staticmethod def seed_enabled_types(): @@ -74,9 +72,9 @@ def seed_enabled_types(): @staticmethod def seed_tricky_types(): return { - 'seed_id_str': 'text', - 'looks_like_a_bool': 'text', - 'looks_like_a_date': 'text', + "seed_id_str": "text", + "looks_like_a_bool": "text", + "looks_like_a_date": "text", } def test_redshift_simple_seed_with_column_override_redshift(self, project): diff --git a/tests/functional/adapter/test_store_test_failures.py b/tests/functional/adapter/test_store_test_failures.py index e6c0f38b4..5d6b70fbb 100644 --- a/tests/functional/adapter/test_store_test_failures.py +++ b/tests/functional/adapter/test_store_test_failures.py @@ -1,4 +1,6 @@ -from dbt.tests.adapter.store_test_failures_tests.test_store_test_failures import TestStoreTestFailures +from dbt.tests.adapter.store_test_failures_tests.test_store_test_failures import ( + TestStoreTestFailures, +) class RedshiftTestStoreTestFailures(TestStoreTestFailures): diff --git a/tests/functional/adapter/utils/test_data_types.py b/tests/functional/adapter/utils/test_data_types.py index 147a962b5..3201afcfb 100644 --- a/tests/functional/adapter/utils/test_data_types.py +++ b/tests/functional/adapter/utils/test_data_types.py @@ -1,4 +1,3 @@ -import pytest from dbt.tests.adapter.utils.data_types.test_type_bigint import BaseTypeBigInt from dbt.tests.adapter.utils.data_types.test_type_float import BaseTypeFloat from dbt.tests.adapter.utils.data_types.test_type_int import BaseTypeInt @@ -11,23 +10,23 @@ class TestTypeBigInt(BaseTypeBigInt): pass - + class TestTypeFloat(BaseTypeFloat): pass - + class TestTypeInt(BaseTypeInt): pass - + class TestTypeNumeric(BaseTypeNumeric): pass - + class TestTypeString(BaseTypeString): pass - + class TestTypeTimestamp(BaseTypeTimestamp): pass diff --git a/tests/functional/adapter/utils/test_timestamps.py b/tests/functional/adapter/utils/test_timestamps.py index 417bfab2b..6c525be44 100644 --- a/tests/functional/adapter/utils/test_timestamps.py +++ b/tests/functional/adapter/utils/test_timestamps.py @@ -17,4 +17,4 @@ def expected_sql(self): select getdate() as current_timestamp, getdate() as current_timestamp_in_utc_backcompat, getdate() as current_timestamp_backcompat - """ \ No newline at end of file + """ diff --git a/tests/functional/adapter/utils/test_utils.py b/tests/functional/adapter/utils/test_utils.py index 03b9cc916..266103fbc 100644 --- a/tests/functional/adapter/utils/test_utils.py +++ b/tests/functional/adapter/utils/test_utils.py @@ -1,5 +1,3 @@ -import pytest - from dbt.tests.adapter.utils.test_array_append import BaseArrayAppend from dbt.tests.adapter.utils.test_array_concat import BaseArrayConcat from dbt.tests.adapter.utils.test_array_construct import BaseArrayConstruct @@ -12,7 +10,6 @@ from dbt.tests.adapter.utils.test_datediff import BaseDateDiff from dbt.tests.adapter.utils.test_date_trunc import BaseDateTrunc from dbt.tests.adapter.utils.test_escape_single_quotes import BaseEscapeSingleQuotesQuote -from dbt.tests.adapter.utils.test_escape_single_quotes import BaseEscapeSingleQuotesBackslash from dbt.tests.adapter.utils.test_except import BaseExcept from dbt.tests.adapter.utils.test_hash import BaseHash from dbt.tests.adapter.utils.test_intersect import BaseIntersect diff --git a/tests/unit/mock_adapter.py b/tests/unit/mock_adapter.py index cc2861e4e..8547480d1 100644 --- a/tests/unit/mock_adapter.py +++ b/tests/unit/mock_adapter.py @@ -1,16 +1,16 @@ from unittest import mock from contextlib import contextmanager -from dbt.adapters.base import BaseAdapter, PythonJobHelper +from dbt.adapters.base import BaseAdapter def adapter_factory(): class MockAdapter(BaseAdapter): - ConnectionManager = mock.MagicMock(TYPE='mock') + ConnectionManager = mock.MagicMock(TYPE="mock") responder = mock.MagicMock() # some convenient defaults responder.quote.side_effect = lambda identifier: '"{}"'.format(identifier) - responder.date_function.side_effect = lambda: 'unitdate()' + responder.date_function.side_effect = lambda: "unitdate()" responder.is_cancelable.side_effect = lambda: False @contextmanager diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index 5170fcfbf..542387c0d 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -15,11 +15,7 @@ from dbt.contracts.graph.model_config import ( NodeConfig, ) -from dbt.contracts.graph.nodes import ( - ModelNode, - DependsOn, - Macro -) +from dbt.contracts.graph.nodes import ModelNode, DependsOn, Macro from dbt.context import providers from dbt.node_types import NodeType @@ -27,73 +23,73 @@ class TestRuntimeWrapper(unittest.TestCase): def setUp(self): self.mock_config = mock.MagicMock() - self.mock_config.quoting = { - 'database': True, 'schema': True, 'identifier': True} + self.mock_config.quoting = {"database": True, "schema": True, "identifier": True} adapter_class = adapter_factory() self.mock_adapter = adapter_class(self.mock_config) self.namespace = mock.MagicMock() - self.wrapper = providers.RuntimeDatabaseWrapper( - self.mock_adapter, self.namespace) + self.wrapper = providers.RuntimeDatabaseWrapper(self.mock_adapter, self.namespace) self.responder = self.mock_adapter.responder PROFILE_DATA = { - 'target': 'test', - 'quoting': {}, - 'outputs': { - 'test': { - 'type': 'redshift', - 'host': 'localhost', - 'schema': 'analytics', - 'user': 'test', - 'pass': 'test', - 'dbname': 'test', - 'port': 1, + "target": "test", + "quoting": {}, + "outputs": { + "test": { + "type": "redshift", + "host": "localhost", + "schema": "analytics", + "user": "test", + "pass": "test", + "dbname": "test", + "port": 1, } }, } PROJECT_DATA = { - 'name': 'root', - 'version': '0.1', - 'profile': 'test', - 'project-root': os.getcwd(), - 'config-version': 2, + "name": "root", + "version": "0.1", + "profile": "test", + "project-root": os.getcwd(), + "config-version": 2, } def model(): return ModelNode( - alias='model_one', - name='model_one', - database='dbt', - schema='analytics', + alias="model_one", + name="model_one", + database="dbt", + schema="analytics", resource_type=NodeType.Model, - unique_id='model.root.model_one', - fqn=['root', 'model_one'], - package_name='root', - original_file_path='model_one.sql', - root_path='/usr/src/app', + unique_id="model.root.model_one", + fqn=["root", "model_one"], + package_name="root", + original_file_path="model_one.sql", + root_path="/usr/src/app", refs=[], sources=[], depends_on=DependsOn(), - config=NodeConfig.from_dict({ - 'enabled': True, - 'materialized': 'view', - 'persist_docs': {}, - 'post-hook': [], - 'pre-hook': [], - 'vars': {}, - 'quoting': {}, - 'column_types': {}, - 'tags': [], - }), + config=NodeConfig.from_dict( + { + "enabled": True, + "materialized": "view", + "persist_docs": {}, + "post-hook": [], + "pre-hook": [], + "vars": {}, + "quoting": {}, + "column_types": {}, + "tags": [], + } + ), tags=[], - path='model_one.sql', - raw_sql='', - description='', - columns={} + path="model_one.sql", + raw_sql="", + description="", + columns={}, ) @@ -101,8 +97,8 @@ def mock_macro(name, package_name): macro = mock.MagicMock( __class__=Macro, package_name=package_name, - resource_type='macro', - unique_id=f'macro.{package_name}.{name}', + resource_type="macro", + unique_id=f"macro.{package_name}.{name}", ) # Mock(name=...) does not set the `name` attribute, this does. macro.name = name @@ -111,7 +107,7 @@ def mock_macro(name, package_name): def mock_manifest(config): manifest_macros = {} - for name in ['macro_a', 'macro_b']: + for name in ["macro_a", "macro_b"]: macro = mock_macro(name, config.project_name) manifest_macros[macro.unique_id] = macro return mock.MagicMock(macros=manifest_macros) @@ -120,47 +116,49 @@ def mock_manifest(config): def mock_model(): return mock.MagicMock( __class__=ModelNode, - alias='model_one', - name='model_one', - database='dbt', - schema='analytics', + alias="model_one", + name="model_one", + database="dbt", + schema="analytics", resource_type=NodeType.Model, - unique_id='model.root.model_one', - fqn=['root', 'model_one'], - package_name='root', - original_file_path='model_one.sql', - root_path='/usr/src/app', + unique_id="model.root.model_one", + fqn=["root", "model_one"], + package_name="root", + original_file_path="model_one.sql", + root_path="/usr/src/app", refs=[], sources=[], depends_on=DependsOn(), - config=NodeConfig.from_dict({ - 'enabled': True, - 'materialized': 'view', - 'persist_docs': {}, - 'post-hook': [], - 'pre-hook': [], - 'vars': {}, - 'quoting': {}, - 'column_types': {}, - 'tags': [], - }), + config=NodeConfig.from_dict( + { + "enabled": True, + "materialized": "view", + "persist_docs": {}, + "post-hook": [], + "pre-hook": [], + "vars": {}, + "quoting": {}, + "column_types": {}, + "tags": [], + } + ), tags=[], - path='model_one.sql', - raw_sql='', - description='', + path="model_one.sql", + raw_sql="", + description="", columns={}, ) @pytest.fixture def get_adapter(): - with mock.patch.object(providers, 'get_adapter') as patch: + with mock.patch.object(providers, "get_adapter") as patch: yield patch @pytest.fixture def get_include_paths(): - with mock.patch.object(factory, 'get_include_paths') as patch: + with mock.patch.object(factory, "get_include_paths") as patch: patch.return_value = [] yield patch @@ -177,12 +175,12 @@ def manifest_fx(config): @pytest.fixture def manifest_extended(manifest_fx): - dbt_macro = mock_macro('default__some_macro', 'dbt') + dbt_macro = mock_macro("default__some_macro", "dbt") # same namespace, same name, different pkg! - rs_macro = mock_macro('redshift__some_macro', 'dbt_redshift') + rs_macro = mock_macro("redshift__some_macro", "dbt_redshift") # same name, different package - package_default_macro = mock_macro('default__some_macro', 'root') - package_rs_macro = mock_macro('redshift__some_macro', 'root') + package_default_macro = mock_macro("default__some_macro", "root") + package_rs_macro = mock_macro("redshift__some_macro", "root") manifest_fx.macros[dbt_macro.unique_id] = dbt_macro manifest_fx.macros[rs_macro.unique_id] = rs_macro manifest_fx.macros[package_default_macro.unique_id] = package_default_macro @@ -200,8 +198,8 @@ def redshift_adapter(config, get_adapter): def test_resolve_specific(config, manifest_extended, redshift_adapter, get_include_paths): - rs_macro = manifest_extended.macros['macro.dbt_redshift.redshift__some_macro'] - package_rs_macro = manifest_extended.macros['macro.root.redshift__some_macro'] + rs_macro = manifest_extended.macros["macro.dbt_redshift.redshift__some_macro"] + package_rs_macro = manifest_extended.macros["macro.root.redshift__some_macro"] ctx = providers.generate_runtime_model_context( model=mock_model(), @@ -209,24 +207,24 @@ def test_resolve_specific(config, manifest_extended, redshift_adapter, get_inclu manifest=manifest_extended, ) - ctx['adapter'].config.dispatch + ctx["adapter"].config.dispatch # macro_a exists, but default__macro_a and redshift__macro_a do not with pytest.raises(dbt.exceptions.CompilationError): - ctx['adapter'].dispatch('macro_a').macro + ctx["adapter"].dispatch("macro_a").macro # root namespace is always preferred, unless search order is explicitly defined in 'dispatch' config - assert ctx['adapter'].dispatch('some_macro').macro is package_rs_macro - assert ctx['adapter'].dispatch('some_macro', 'dbt').macro is package_rs_macro - assert ctx['adapter'].dispatch('some_macro', 'root').macro is package_rs_macro + assert ctx["adapter"].dispatch("some_macro").macro is package_rs_macro + assert ctx["adapter"].dispatch("some_macro", "dbt").macro is package_rs_macro + assert ctx["adapter"].dispatch("some_macro", "root").macro is package_rs_macro # override 'dbt' namespace search order, dispatch to 'root' first - ctx['adapter'].config.dispatch = [{'macro_namespace': 'dbt', 'search_order': ['root', 'dbt']}] - assert ctx['adapter'].dispatch('some_macro', macro_namespace='dbt').macro is package_rs_macro + ctx["adapter"].config.dispatch = [{"macro_namespace": "dbt", "search_order": ["root", "dbt"]}] + assert ctx["adapter"].dispatch("some_macro", macro_namespace="dbt").macro is package_rs_macro # override 'dbt' namespace search order, dispatch to 'dbt' only - ctx['adapter'].config.dispatch = [{'macro_namespace': 'dbt', 'search_order': ['dbt']}] - assert ctx['adapter'].dispatch('some_macro', macro_namespace='dbt').macro is rs_macro + ctx["adapter"].config.dispatch = [{"macro_namespace": "dbt", "search_order": ["dbt"]}] + assert ctx["adapter"].dispatch("some_macro", macro_namespace="dbt").macro is rs_macro # override 'root' namespace search order, dispatch to 'dbt' first - ctx['adapter'].config.dispatch = [{'macro_namespace': 'root', 'search_order': ['dbt', 'root']}] + ctx["adapter"].config.dispatch = [{"macro_namespace": "root", "search_order": ["dbt", "root"]}] diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index ba5361b0b..27bcd98f8 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -3,7 +3,6 @@ from unittest.mock import Mock, call import agate -import boto3 import dbt import redshift_connector @@ -14,37 +13,41 @@ from dbt.clients import agate_helper from dbt.exceptions import FailedToConnectError from dbt.adapters.redshift.connections import RedshiftConnectMethodFactory -from .utils import config_from_parts_or_dicts, mock_connection, TestAdapterConversions, inject_adapter +from .utils import ( + config_from_parts_or_dicts, + mock_connection, + TestAdapterConversions, + inject_adapter, +) class TestRedshiftAdapter(unittest.TestCase): - def setUp(self): profile_cfg = { - 'outputs': { - 'test': { - 'type': 'redshift', - 'dbname': 'redshift', - 'user': 'root', - 'host': 'thishostshouldnotexist.test.us-east-1', - 'pass': 'password', - 'port': 5439, - 'schema': 'public' + "outputs": { + "test": { + "type": "redshift", + "dbname": "redshift", + "user": "root", + "host": "thishostshouldnotexist.test.us-east-1", + "pass": "password", + "port": 5439, + "schema": "public", } }, - 'target': 'test' + "target": "test", } project_cfg = { - 'name': 'X', - 'version': '0.1', - 'profile': 'test', - 'project-root': '/tmp/dbt/does-not-exist', - 'quoting': { - 'identifier': False, - 'schema': True, + "name": "X", + "version": "0.1", + "profile": "test", + "project-root": "/tmp/dbt/does-not-exist", + "quoting": { + "identifier": False, + "schema": True, }, - 'config-version': 2, + "config-version": 2, } self.config = config_from_parts_or_dicts(project_cfg, profile_cfg) @@ -62,174 +65,171 @@ def test_implicit_database_conn(self): connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( - host='thishostshouldnotexist.test.us-east-1', - database='redshift', - user='root', - password='password', + host="thishostshouldnotexist.test.us-east-1", + database="redshift", + user="root", + password="password", port=5439, auto_create=False, db_groups=[], timeout=30, - region='us-east-1' + region="us-east-1", ) @mock.patch("redshift_connector.connect", Mock()) def test_explicit_database_conn(self): - self.config.method = 'database' + self.config.method = "database" connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( - host='thishostshouldnotexist.test.us-east-1', - database='redshift', - user='root', - password='password', + host="thishostshouldnotexist.test.us-east-1", + database="redshift", + user="root", + password="password", port=5439, auto_create=False, db_groups=[], - region='us-east-1', - timeout=30 + region="us-east-1", + timeout=30, ) @mock.patch("redshift_connector.connect", Mock()) def test_explicit_iam_conn_without_profile(self): self.config.credentials = self.config.credentials.replace( - method='iam', - cluster_id='my_redshift', - host='thishostshouldnotexist.test.us-east-1' + method="iam", cluster_id="my_redshift", host="thishostshouldnotexist.test.us-east-1" ) connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( iam=True, - host='thishostshouldnotexist.test.us-east-1', - database='redshift', - db_user='root', - password='', - user='', - cluster_identifier='my_redshift', - region='us-east-1', + host="thishostshouldnotexist.test.us-east-1", + database="redshift", + db_user="root", + password="", + user="", + cluster_identifier="my_redshift", + region="us-east-1", auto_create=False, db_groups=[], profile=None, timeout=30, - port=5439 + port=5439, ) - @mock.patch('redshift_connector.connect', Mock()) - @mock.patch('boto3.Session', Mock()) + @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("boto3.Session", Mock()) def test_explicit_iam_conn_with_profile(self): self.config.credentials = self.config.credentials.replace( - method='iam', - cluster_id='my_redshift', - iam_profile='test', - host='thishostshouldnotexist.test.us-east-1' + method="iam", + cluster_id="my_redshift", + iam_profile="test", + host="thishostshouldnotexist.test.us-east-1", ) connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( iam=True, - host='thishostshouldnotexist.test.us-east-1', - database='redshift', - cluster_identifier='my_redshift', - region='us-east-1', + host="thishostshouldnotexist.test.us-east-1", + database="redshift", + cluster_identifier="my_redshift", + region="us-east-1", auto_create=False, db_groups=[], - db_user='root', - password='', - user='', - profile='test', + db_user="root", + password="", + user="", + profile="test", timeout=30, - port=5439 + port=5439, ) - @mock.patch('redshift_connector.connect', Mock()) - @mock.patch('boto3.Session', Mock()) + @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("boto3.Session", Mock()) def test_explicit_iam_serverless_with_profile(self): self.config.credentials = self.config.credentials.replace( - method='iam', - iam_profile='test', - host='doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com' + method="iam", + iam_profile="test", + host="doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com", ) connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( iam=True, - host='doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com', - database='redshift', + host="doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com", + database="redshift", cluster_identifier=None, - region='us-east-2', + region="us-east-2", auto_create=False, db_groups=[], - db_user='root', - password='', - user='', - profile='test', + db_user="root", + password="", + user="", + profile="test", timeout=30, - port=5439 + port=5439, ) - @mock.patch('redshift_connector.connect', Mock()) - @mock.patch('boto3.Session', Mock()) + @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("boto3.Session", Mock()) def test_serverless_iam_failure(self): self.config.credentials = self.config.credentials.replace( - method='iam', - iam_profile='test', - host='doesnotexist.1233.us-east-2.redshift-srvrlss.amazonaws.com' + method="iam", + iam_profile="test", + host="doesnotexist.1233.us-east-2.redshift-srvrlss.amazonaws.com", ) with self.assertRaises(dbt.exceptions.FailedToConnectError) as context: connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( iam=True, - host='doesnotexist.1233.us-east-2.redshift-srvrlss.amazonaws.com', - database='redshift', + host="doesnotexist.1233.us-east-2.redshift-srvrlss.amazonaws.com", + database="redshift", cluster_identifier=None, - region='us-east-2', + region="us-east-2", auto_create=False, db_groups=[], - db_user='root', - password='', - user='', - profile='test', + db_user="root", + password="", + user="", + profile="test", port=5439, timeout=30, - ) + ) self.assertTrue("'host' must be provided" in context.exception.msg) def test_iam_conn_optionals(self): - profile_cfg = { - 'outputs': { - 'test': { - 'type': 'redshift', - 'dbname': 'redshift', - 'user': 'root', - 'host': 'thishostshouldnotexist', - 'port': 5439, - 'schema': 'public', - 'method': 'iam', - 'cluster_id': 'my_redshift', - 'db_groups': ["my_dbgroup"], - 'autocreate': True, + "outputs": { + "test": { + "type": "redshift", + "dbname": "redshift", + "user": "root", + "host": "thishostshouldnotexist", + "port": 5439, + "schema": "public", + "method": "iam", + "cluster_id": "my_redshift", + "db_groups": ["my_dbgroup"], + "autocreate": True, } }, - 'target': 'test' + "target": "test", } config_from_parts_or_dicts(self.config, profile_cfg) def test_invalid_auth_method(self): # we have to set method this way, otherwise it won't validate - self.config.credentials.method = 'badmethod' + self.config.credentials.method = "badmethod" with self.assertRaises(FailedToConnectError) as context: connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) connect_method_factory.get_connect_method() - self.assertTrue('badmethod' in context.exception.msg) + self.assertTrue("badmethod" in context.exception.msg) def test_invalid_iam_no_cluster_id(self): - self.config.credentials = self.config.credentials.replace(method='iam') + self.config.credentials = self.config.credentials.replace(method="iam") with self.assertRaises(FailedToConnectError) as context: connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) connect_method_factory.get_connect_method() @@ -241,171 +241,195 @@ def test_cancel_open_connections_empty(self): def test_cancel_open_connections_master(self): key = self.adapter.connections.get_thread_identifier() - self.adapter.connections.thread_connections[key] = mock_connection('master') + self.adapter.connections.thread_connections[key] = mock_connection("master") self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0) def test_cancel_open_connections_single(self): - master = mock_connection('master') - model = mock_connection('model') + master = mock_connection("master") + model = mock_connection("model") key = self.adapter.connections.get_thread_identifier() - self.adapter.connections.thread_connections.update({ - key: master, - 1: model, - }) - with mock.patch.object(self.adapter.connections, 'add_query') as add_query: + self.adapter.connections.thread_connections.update( + { + key: master, + 1: model, + } + ) + with mock.patch.object(self.adapter.connections, "add_query") as add_query: query_result = mock.MagicMock() cursor = mock.Mock() cursor.fetchone.return_value = 42 add_query.side_effect = [(None, cursor), (None, query_result)] self.assertEqual(len(list(self.adapter.cancel_open_connections())), 1) - add_query.assert_has_calls([call('select pg_backend_pid()'), call('select pg_terminate_backend(42)')]) + add_query.assert_has_calls( + [call("select pg_backend_pid()"), call("select pg_terminate_backend(42)")] + ) master.handle.get_backend_pid.assert_not_called() def test_dbname_verification_is_case_insensitive(self): # Override adapter settings from setUp() profile_cfg = { - 'outputs': { - 'test': { - 'type': 'redshift', - 'dbname': 'Redshift', - 'user': 'root', - 'host': 'thishostshouldnotexist', - 'pass': 'password', - 'port': 5439, - 'schema': 'public' + "outputs": { + "test": { + "type": "redshift", + "dbname": "Redshift", + "user": "root", + "host": "thishostshouldnotexist", + "pass": "password", + "port": 5439, + "schema": "public", } }, - 'target': 'test' + "target": "test", } project_cfg = { - 'name': 'X', - 'version': '0.1', - 'profile': 'test', - 'project-root': '/tmp/dbt/does-not-exist', - 'quoting': { - 'identifier': False, - 'schema': True, + "name": "X", + "version": "0.1", + "profile": "test", + "project-root": "/tmp/dbt/does-not-exist", + "quoting": { + "identifier": False, + "schema": True, }, - 'config-version': 2, + "config-version": 2, } self.config = config_from_parts_or_dicts(project_cfg, profile_cfg) self.adapter.cleanup_connections() self._adapter = RedshiftAdapter(self.config) - self.adapter.verify_database('redshift') + self.adapter.verify_database("redshift") def test_execute_with_fetch(self): cursor = mock.Mock() table = dbt.clients.agate_helper.empty_table() - with mock.patch.object(self.adapter.connections, 'add_query') as mock_add_query: + with mock.patch.object(self.adapter.connections, "add_query") as mock_add_query: mock_add_query.return_value = ( - None, cursor) # when mock_add_query is called, it will always return None, cursor - with mock.patch.object(self.adapter.connections, 'get_response') as mock_get_response: + None, + cursor, + ) # when mock_add_query is called, it will always return None, cursor + with mock.patch.object(self.adapter.connections, "get_response") as mock_get_response: mock_get_response.return_value = None - with mock.patch.object(self.adapter.connections, - 'get_result_from_cursor') as mock_get_result_from_cursor: + with mock.patch.object( + self.adapter.connections, "get_result_from_cursor" + ) as mock_get_result_from_cursor: mock_get_result_from_cursor.return_value = table self.adapter.connections.execute(sql="select * from test", fetch=True) - mock_add_query.assert_called_once_with('select * from test', False) + mock_add_query.assert_called_once_with("select * from test", False) mock_get_result_from_cursor.assert_called_once_with(cursor) mock_get_response.assert_called_once_with(cursor) def test_execute_without_fetch(self): cursor = mock.Mock() - with mock.patch.object(self.adapter.connections, 'add_query') as mock_add_query: + with mock.patch.object(self.adapter.connections, "add_query") as mock_add_query: mock_add_query.return_value = ( - None, cursor) # when mock_add_query is called, it will always return None, cursor - with mock.patch.object(self.adapter.connections, 'get_response') as mock_get_response: + None, + cursor, + ) # when mock_add_query is called, it will always return None, cursor + with mock.patch.object(self.adapter.connections, "get_response") as mock_get_response: mock_get_response.return_value = None - with mock.patch.object(self.adapter.connections, - 'get_result_from_cursor') as mock_get_result_from_cursor: + with mock.patch.object( + self.adapter.connections, "get_result_from_cursor" + ) as mock_get_result_from_cursor: self.adapter.connections.execute(sql="select * from test2", fetch=False) - mock_add_query.assert_called_once_with('select * from test2', False) + mock_add_query.assert_called_once_with("select * from test2", False) mock_get_result_from_cursor.assert_not_called() mock_get_response.assert_called_once_with(cursor) def test_add_query_with_no_cursor(self): - with mock.patch.object(self.adapter.connections, 'get_thread_connection') as mock_get_thread_connection: + with mock.patch.object( + self.adapter.connections, "get_thread_connection" + ) as mock_get_thread_connection: mock_get_thread_connection.return_value = None - with self.assertRaisesRegex(dbt.exceptions.DbtRuntimeError, - 'Tried to run invalid SQL: on '): + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, "Tried to run invalid SQL: on " + ): self.adapter.connections.add_query(sql="") mock_get_thread_connection.assert_called_once() def test_add_query_success(self): cursor = mock.Mock() - with mock.patch.object(dbt.adapters.redshift.connections.SQLConnectionManager, 'add_query') as mock_add_query: + with mock.patch.object( + dbt.adapters.redshift.connections.SQLConnectionManager, "add_query" + ) as mock_add_query: mock_add_query.return_value = None, cursor - self.adapter.connections.add_query('select * from test3') - mock_add_query.assert_called_once_with('select * from test3', True, bindings=None, abridge_sql_log=False) + self.adapter.connections.add_query("select * from test3") + mock_add_query.assert_called_once_with( + "select * from test3", True, bindings=None, abridge_sql_log=False + ) + class TestRedshiftAdapterConversions(TestAdapterConversions): def test_convert_text_type(self): rows = [ - ['', 'a1', 'stringval1'], - ['', 'a2', 'stringvalasdfasdfasdfa'], - ['', 'a3', 'stringval3'], + ["", "a1", "stringval1"], + ["", "a2", "stringvalasdfasdfasdfa"], + ["", "a3", "stringval3"], ] agate_table = self._make_table_of(rows, agate.Text) - expected = ['varchar(64)', 'varchar(2)', 'varchar(22)'] + expected = ["varchar(64)", "varchar(2)", "varchar(22)"] for col_idx, expect in enumerate(expected): assert RedshiftAdapter.convert_text_type(agate_table, col_idx) == expect def test_convert_number_type(self): rows = [ - ['', '23.98', '-1'], - ['', '12.78', '-2'], - ['', '79.41', '-3'], + ["", "23.98", "-1"], + ["", "12.78", "-2"], + ["", "79.41", "-3"], ] agate_table = self._make_table_of(rows, agate.Number) - expected = ['integer', 'float8', 'integer'] + expected = ["integer", "float8", "integer"] for col_idx, expect in enumerate(expected): assert RedshiftAdapter.convert_number_type(agate_table, col_idx) == expect def test_convert_boolean_type(self): rows = [ - ['', 'false', 'true'], - ['', 'false', 'false'], - ['', 'false', 'true'], + ["", "false", "true"], + ["", "false", "false"], + ["", "false", "true"], ] agate_table = self._make_table_of(rows, agate.Boolean) - expected = ['boolean', 'boolean', 'boolean'] + expected = ["boolean", "boolean", "boolean"] for col_idx, expect in enumerate(expected): assert RedshiftAdapter.convert_boolean_type(agate_table, col_idx) == expect def test_convert_datetime_type(self): rows = [ - ['', '20190101T01:01:01Z', '2019-01-01 01:01:01'], - ['', '20190102T01:01:01Z', '2019-01-01 01:01:01'], - ['', '20190103T01:01:01Z', '2019-01-01 01:01:01'], + ["", "20190101T01:01:01Z", "2019-01-01 01:01:01"], + ["", "20190102T01:01:01Z", "2019-01-01 01:01:01"], + ["", "20190103T01:01:01Z", "2019-01-01 01:01:01"], + ] + agate_table = self._make_table_of( + rows, [agate.DateTime, agate_helper.ISODateTime, agate.DateTime] + ) + expected = [ + "timestamp without time zone", + "timestamp without time zone", + "timestamp without time zone", ] - agate_table = self._make_table_of(rows, [agate.DateTime, agate_helper.ISODateTime, agate.DateTime]) - expected = ['timestamp without time zone', 'timestamp without time zone', 'timestamp without time zone'] for col_idx, expect in enumerate(expected): assert RedshiftAdapter.convert_datetime_type(agate_table, col_idx) == expect def test_convert_date_type(self): rows = [ - ['', '2019-01-01', '2019-01-04'], - ['', '2019-01-02', '2019-01-04'], - ['', '2019-01-03', '2019-01-04'], + ["", "2019-01-01", "2019-01-04"], + ["", "2019-01-02", "2019-01-04"], + ["", "2019-01-03", "2019-01-04"], ] agate_table = self._make_table_of(rows, agate.Date) - expected = ['date', 'date', 'date'] + expected = ["date", "date", "date"] for col_idx, expect in enumerate(expected): assert RedshiftAdapter.convert_date_type(agate_table, col_idx) == expect def test_convert_time_type(self): # dbt's default type testers actually don't have a TimeDelta at all. rows = [ - ['', '120s', '10s'], - ['', '3m', '11s'], - ['', '1h', '12s'], + ["", "120s", "10s"], + ["", "3m", "11s"], + ["", "1h", "12s"], ] agate_table = self._make_table_of(rows, agate.TimeDelta) - expected = ['varchar(24)', 'varchar(24)', 'varchar(24)'] + expected = ["varchar(24)", "varchar(24)", "varchar(24)"] for col_idx, expect in enumerate(expected): assert RedshiftAdapter.convert_time_type(agate_table, col_idx) == expect diff --git a/tests/unit/utils.py b/tests/unit/utils.py index e09b7fc69..f2ca418e3 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -26,21 +26,22 @@ def normalize(path): class Obj: - which = 'blah' + which = "blah" single_threaded = False -def mock_connection(name, state='open'): +def mock_connection(name, state="open"): conn = mock.MagicMock() conn.name = name conn.state = state return conn -def profile_from_dict(profile, profile_name, cli_vars='{}'): +def profile_from_dict(profile, profile_name, cli_vars="{}"): from dbt.config import Profile from dbt.config.renderer import ProfileRenderer from dbt.config.utils import parse_cli_vars + if not isinstance(cli_vars, dict): cli_vars = parse_cli_vars(cli_vars) @@ -50,6 +51,7 @@ def profile_from_dict(profile, profile_name, cli_vars='{}'): # flags global. This is a bit of a hack, but it's the best way to do it. from dbt.flags import set_from_args from argparse import Namespace + set_from_args(Namespace(), None) return Profile.from_raw_profile_info( profile, @@ -58,15 +60,16 @@ def profile_from_dict(profile, profile_name, cli_vars='{}'): ) -def project_from_dict(project, profile, packages=None, selectors=None, cli_vars='{}'): +def project_from_dict(project, profile, packages=None, selectors=None, cli_vars="{}"): from dbt.config.renderer import DbtProjectYamlRenderer from dbt.config.utils import parse_cli_vars + if not isinstance(cli_vars, dict): cli_vars = parse_cli_vars(cli_vars) renderer = DbtProjectYamlRenderer(profile, cli_vars) - project_root = project.pop('project-root', os.getcwd()) + project_root = project.pop("project-root", os.getcwd()) partial = PartialProject.from_dicts( project_root=project_root, @@ -77,7 +80,7 @@ def project_from_dict(project, profile, packages=None, selectors=None, cli_vars= return partial.render(renderer) -def config_from_parts_or_dicts(project, profile, packages=None, selectors=None, cli_vars='{}'): +def config_from_parts_or_dicts(project, profile, packages=None, selectors=None, cli_vars="{}"): from dbt.config import Project, Profile, RuntimeConfig from dbt.config.utils import parse_cli_vars from copy import deepcopy @@ -88,7 +91,7 @@ def config_from_parts_or_dicts(project, profile, packages=None, selectors=None, if isinstance(project, Project): profile_name = project.profile_name else: - profile_name = project.get('profile') + profile_name = project.get("profile") if not isinstance(profile, Profile): profile = profile_from_dict( @@ -108,16 +111,13 @@ def config_from_parts_or_dicts(project, profile, packages=None, selectors=None, args = Obj() args.vars = cli_vars - args.profile_dir = '/dev/null' - return RuntimeConfig.from_parts( - project=project, - profile=profile, - args=args - ) + args.profile_dir = "/dev/null" + return RuntimeConfig.from_parts(project=project, profile=profile, args=args) def inject_plugin(plugin): from dbt.adapters.factory import FACTORY + key = plugin.adapter.type() FACTORY.plugins[key] = plugin @@ -125,8 +125,11 @@ def inject_plugin(plugin): def inject_plugin_for(config): # from dbt.adapters.postgres import Plugin, PostgresAdapter from dbt.adapters.factory import FACTORY + FACTORY.load_plugin(config.credentials.type) - adapter = FACTORY.get_adapter(config) # TODO: there's a get_adaptor function in factory.py, but no method on AdapterContainer + adapter = FACTORY.get_adapter( + config + ) # TODO: there's a get_adaptor function in factory.py, but no method on AdapterContainer return adapter @@ -136,12 +139,14 @@ def inject_adapter(value, plugin): """ inject_plugin(plugin) from dbt.adapters.factory import FACTORY + key = value.type() FACTORY.adapters[key] = value def clear_plugin(plugin): from dbt.adapters.factory import FACTORY + key = plugin.adapter.type() FACTORY.plugins.pop(key, None) FACTORY.adapters.pop(key, None) @@ -184,7 +189,7 @@ def compare_dicts(dict1, dict2): common_keys = set(first_set).intersection(set(second_set)) found_differences = False for key in common_keys: - if dict1[key] != dict2[key] : + if dict1[key] != dict2[key]: print(f"--- --- first dict: {key}: {str(dict1[key])}") print(f"--- --- second dict: {key}: {str(dict2[key])}") found_differences = True @@ -199,7 +204,7 @@ def assert_from_dict(obj, dct, cls=None): cls = obj.__class__ cls.validate(dct) obj_from_dict = cls.from_dict(dct) - if hasattr(obj, 'created_at'): + if hasattr(obj, "created_at"): obj_from_dict.created_at = 1 obj.created_at = 1 assert obj_from_dict == obj @@ -207,10 +212,10 @@ def assert_from_dict(obj, dct, cls=None): def assert_to_dict(obj, dct): obj_to_dict = obj.to_dict(omit_none=True) - if 'created_at' in obj_to_dict: - obj_to_dict['created_at'] = 1 - if 'created_at' in dct: - dct['created_at'] = 1 + if "created_at" in obj_to_dict: + obj_to_dict["created_at"] = 1 + if "created_at" in dct: + dct["created_at"] = 1 assert obj_to_dict == dct @@ -226,10 +231,10 @@ def assert_fails_validation(dct, cls): class TestAdapterConversions(TestCase): - @staticmethod def _get_tester_for(column_type): from dbt.clients import agate_helper + if column_type is agate.TimeDelta: # dbt never makes this! return agate.TimeDelta() @@ -237,10 +242,10 @@ def _get_tester_for(column_type): if isinstance(instance, column_type): return instance - raise ValueError(f'no tester for {column_type}') + raise ValueError(f"no tester for {column_type}") def _make_table_of(self, rows, column_types): - column_names = list(string.ascii_letters[:len(rows[0])]) + column_names = list(string.ascii_letters[: len(rows[0])]) if isinstance(column_types, type): column_types = [self._get_tester_for(column_types) for _ in column_names] else: @@ -251,6 +256,7 @@ def _make_table_of(self, rows, column_types): def load_internal_manifest_macros(config, macro_hook=lambda m: None): from dbt.parser.manifest import ManifestLoader + return ManifestLoader.load_macros(config, macro_hook)