diff --git a/test.env.example b/test.env.example index 6816b4ec2..c743c6aec 100644 --- a/test.env.example +++ b/test.env.example @@ -5,6 +5,7 @@ REDSHIFT_TEST_HOST= REDSHIFT_TEST_PORT= REDSHIFT_TEST_DBNAME= +REDSHIFT_TEST_DBNAME_ALT= REDSHIFT_TEST_USER= REDSHIFT_TEST_PASS= REDSHIFT_TEST_REGION= diff --git a/tests/boundary/conftest.py b/tests/boundary/conftest.py new file mode 100644 index 000000000..6e1705b2b --- /dev/null +++ b/tests/boundary/conftest.py @@ -0,0 +1,41 @@ +from datetime import datetime +import os +import random +from typing import Any, Dict + +import pytest +import redshift_connector + + +@pytest.fixture +def connection(connection_config) -> redshift_connector.Connection: + return redshift_connector.connect(**connection_config) + + +@pytest.fixture +def connection_alt(connection_config) -> redshift_connector.Connection: + config = connection_config.copy() + config.update(database=os.getenv("REDSHIFT_TEST_DBNAME_ALT")) + return redshift_connector.connect(**config) + + +@pytest.fixture +def connection_config() -> Dict[str, Any]: + return { + "user": os.getenv("REDSHIFT_TEST_USER"), + "password": os.getenv("REDSHIFT_TEST_PASS"), + "host": os.getenv("REDSHIFT_TEST_HOST"), + "port": int(os.getenv("REDSHIFT_TEST_PORT")), + "database": os.getenv("REDSHIFT_TEST_DBNAME"), + "region": os.getenv("REDSHIFT_TEST_REGION"), + } + + +@pytest.fixture +def schema_name(request) -> str: + runtime = datetime.utcnow() - datetime(1970, 1, 1, 0, 0, 0) + runtime_s = int(runtime.total_seconds()) + runtime_ms = runtime.microseconds + random_int = random.randint(0, 9999) + file_name = request.module.__name__.split(".")[-1] + return f"test_{runtime_s}{runtime_ms}{random_int:04}_{file_name}" diff --git a/tests/boundary/test_redshift_connector.py b/tests/boundary/test_redshift_connector.py new file mode 100644 index 000000000..90cf28b70 --- /dev/null +++ b/tests/boundary/test_redshift_connector.py @@ -0,0 +1,61 @@ +import os + +import pytest + + +@pytest.fixture(autouse=True) +def setup(connection, connection_alt, schema_name) -> str: + # create the same table in two different databases + with connection.cursor() as cursor: + cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}") + cursor.execute(f"CREATE TABLE {schema_name}.cross_db as select 3.14 as id") + with connection_alt.cursor() as cursor: + cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}") + cursor.execute(f"CREATE TABLE {schema_name}.cross_db as select 3.14 as id") + + yield schema_name + + # drop both test schemas + with connection_alt.cursor() as cursor: + cursor.execute(f"DROP SCHEMA IF EXISTS {schema_name} CASCADE") + with connection.cursor() as cursor: + cursor.execute(f"DROP SCHEMA IF EXISTS {schema_name} CASCADE") + + +def test_columns_in_relation(connection, schema_name): + # we're specifically running this query from the default database + # we're expecting to get both tables, the one in the default database and the one in the alt database + with connection.cursor() as cursor: + columns = cursor.get_columns(schema_pattern=schema_name, tablename_pattern="cross_db") + + # we should have the same table in both databases + assert len(columns) == 2 + + databases = set() + for column in columns: + ( + database, + schema, + table, + name, + type_code, + type_name, + precision, + _, + scale, + *_, + ) = column + databases.add(database) + assert schema_name == schema_name + assert table == "cross_db" + assert name == "id" + assert type_code == 2 + assert type_name == "numeric" + assert precision == 3 + assert scale == 2 + + # only the databases are different + assert databases == { + os.getenv("REDSHIFT_TEST_DBNAME"), + os.getenv("REDSHIFT_TEST_DBNAME_ALT"), + } diff --git a/tests/functional/test_columns_in_relation.py b/tests/functional/test_columns_in_relation.py new file mode 100644 index 000000000..581a60a94 --- /dev/null +++ b/tests/functional/test_columns_in_relation.py @@ -0,0 +1,41 @@ +from dbt.tests.util import get_connection, run_dbt +import pytest + + +MY_CROSS_DB_SOURCES = """ +version: 2 +sources: + - name: ci + schema: adapter + tables: + - name: cross_db + - name: ci_alt + database: ci_alt + schema: adapter + tables: + - name: cross_db +""" + + +class TestCrossDatabase: + """ + This addresses https://github.com/dbt-labs/dbt-redshift/issues/736 + """ + + @pytest.fixture(scope="class") + def models(self): + my_model = """ + select '{{ adapter.get_columns_in_relation(source('ci', 'cross_db')) }}' as columns + union all + select '{{ adapter.get_columns_in_relation(source('ci_alt', 'cross_db')) }}' as columns + """ + return { + "sources.yml": MY_CROSS_DB_SOURCES, + "my_model.sql": my_model, + } + + def test_columns_in_relation(self, project): + run_dbt(["run"]) + with get_connection(project.adapter, "_test"): + records = project.run_sql(f"select * from {project.test_schema}.my_model", fetch=True) + assert len(records) == 2