diff --git a/.changes/unreleased/Features-20240826-123954.yaml b/.changes/unreleased/Features-20240826-123954.yaml new file mode 100644 index 000000000..209bf9dc7 --- /dev/null +++ b/.changes/unreleased/Features-20240826-123954.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Remove `pg_catalog` from metadata queries +time: 2024-08-26T12:39:54.481505-04:00 +custom: + Author: mikealfare, jiezhen-chen + Issue: "555" diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index d3fbcafea..8e7ae36d2 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -3,6 +3,7 @@ from contextlib import contextmanager from typing import Any, Callable, Dict, Tuple, Union, Optional, List, TYPE_CHECKING from dataclasses import dataclass, field +import time import sqlparse import redshift_connector @@ -12,10 +13,14 @@ from dbt.adapters.sql import SQLConnectionManager from dbt.adapters.contracts.connection import AdapterResponse, Connection, Credentials from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.events.types import SQLQuery, SQLQueryStatus from dbt_common.contracts.util import Replaceable from dbt_common.dataclass_schema import dbtClassMixin, StrEnum, ValidationError +from dbt_common.events.contextvars import get_node_info +from dbt_common.events.functions import fire_event from dbt_common.helper_types import Port from dbt_common.exceptions import DbtRuntimeError, CompilationError, DbtDatabaseError +from dbt_common.utils import cast_to_str if TYPE_CHECKING: # Indirectly imported via agate_helper, which is lazy loaded further downfile. 
# dbt/adapters/redshift/connections.py -- methods added to the connection manager class.


def columns_in_relation(self, relation) -> List[Dict[str, Any]]:
    """Describe the columns of ``relation`` through the driver's metadata call.

    Uses ``get_columns`` on a cursor of the current thread's connection and
    fires a ``SQLQuery`` event before the call and a ``SQLQueryStatus`` event
    (with the elapsed wall time) after it.

    Args:
        relation: Object exposing ``database``, ``schema`` and ``identifier``.

    Returns:
        One dict per column, in the shape built by ``_parse_column_results``.
    """
    connection = self.get_thread_connection()

    fire_event(
        SQLQuery(
            conn_name=cast_to_str(connection.name),
            sql=f"call redshift_connector.Connection.get_columns({relation.database}, {relation.schema}, {relation.identifier})",
            node_info=get_node_info(),
        )
    )

    pre = time.perf_counter()

    # Use the cursor as a context manager so it is released even when the
    # metadata call or the status event raises.
    with connection.handle.cursor() as cursor:
        # NOTE(review): the keyword names `schema_pattern` / `tablename_pattern`
        # suggest pattern matching -- confirm that identifiers containing
        # wildcard characters (e.g. `_`) cannot match additional relations.
        columns = cursor.get_columns(
            catalog=relation.database,
            schema_pattern=relation.schema,
            tablename_pattern=relation.identifier,
        )

        fire_event(
            SQLQueryStatus(
                status=str(self.get_response(cursor)),
                elapsed=time.perf_counter() - pre,
                node_info=get_node_info(),
            )
        )

        # Parse while the cursor is still open so nothing depends on its
        # state after it has been closed.
        return [self._parse_column_results(column) for column in columns]


@staticmethod
def _parse_column_results(record: Tuple[Any, ...]) -> Dict[str, Any]:
    """Convert one ``get_columns`` row into keyword arguments for a column.

    Args:
        record: Row with at least nine fields; the fields at index 3, 4, 5, 6
            and 8 are read as column name, type code, type name, column size
            and decimal digits. All other fields are ignored.

    Returns:
        ``{"column", "dtype", "numeric_precision", "numeric_scale"}`` when the
        type code is one of the numeric codes below, otherwise
        ``{"column", "dtype", "char_size"}``.
    """
    _, _, _, column_name, dtype_code, dtype_name, column_size, _, decimals, *_ = record

    # Type codes reported with precision and scale.
    # NOTE(review): presumably java.sql.Types-style codes -- confirm against
    # the redshift_connector documentation.
    num_dtypes = (2, 3, 4, 5, 6, 7, 8, -5, 2003)

    if dtype_code in num_dtypes:
        return {
            "column": column_name,
            "dtype": dtype_name,
            "numeric_precision": column_size,
            "numeric_scale": decimals,
        }

    # Every other code -- the character codes 1 and 12 included -- is
    # reported with its size as `char_size`.
    return {"column": column_name, "dtype": dtype_name, "char_size": column_size}
# dbt/adapters/redshift/impl.py -- members added to RedshiftAdapter.


@property
def _behavior_flags(self) -> List[BehaviorFlag]:
    """Behavior flags declared by this adapter.

    Declares a single flag, ``restrict_direct_pg_catalog_access``, whose
    default is ``False``.
    """
    pg_catalog_flag: BehaviorFlag = {
        "name": "restrict_direct_pg_catalog_access",
        "default": False,
    }
    return [pg_catalog_flag]


def get_columns_in_relation(self, relation) -> List[Column]:
    """Return the columns of ``relation``.

    With the ``restrict_direct_pg_catalog_access`` behavior flag off, this
    defers to the parent implementation. With the flag on, the column
    descriptions come from ``self.connections.columns_in_relation`` and each
    one is expanded into a ``Column``.
    """
    if not self.behavior.restrict_direct_pg_catalog_access:
        return super().get_columns_in_relation(relation)

    described = self.connections.columns_in_relation(relation)
    return [Column(**spec) for spec in described]
# tests/boundary/conftest.py

from datetime import datetime, timezone
import os
import random
from typing import Iterator

import pytest
import redshift_connector


@pytest.fixture
def connection() -> Iterator[redshift_connector.Connection]:
    """Yield a ``redshift_connector`` connection built from ``REDSHIFT_TEST_*``.

    Reads ``REDSHIFT_TEST_USER``, ``_PASS``, ``_HOST``, ``_PORT``, ``_DBNAME``
    and ``_REGION`` from the environment. ``int()`` raises ``TypeError`` when
    ``REDSHIFT_TEST_PORT`` is unset. The connection is closed on teardown so
    test runs do not leave sessions open.
    """
    conn = redshift_connector.connect(
        user=os.getenv("REDSHIFT_TEST_USER"),
        password=os.getenv("REDSHIFT_TEST_PASS"),
        host=os.getenv("REDSHIFT_TEST_HOST"),
        port=int(os.getenv("REDSHIFT_TEST_PORT")),
        database=os.getenv("REDSHIFT_TEST_DBNAME"),
        region=os.getenv("REDSHIFT_TEST_REGION"),
    )
    try:
        yield conn
    finally:
        conn.close()


@pytest.fixture
def schema_name(request) -> str:
    """Build a schema name of the form ``test_<time><random>_<module>``.

    ``<time>`` is the whole seconds since 1970-01-01 UTC followed by the
    microsecond component, ``<random>`` is a zero-padded four-digit integer,
    and ``<module>`` is the last dotted segment of the requesting module.
    """
    # Timezone-aware arithmetic; `datetime.utcnow()` is deprecated.
    epoch = datetime(1970, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
    runtime = datetime.now(timezone.utc) - epoch
    runtime_s = int(runtime.total_seconds())
    runtime_us = runtime.microseconds
    random_int = random.randint(0, 9999)
    file_name = request.module.__name__.split(".")[-1]
    return f"test_{runtime_s}{runtime_us}{random_int:04}_{file_name}"


# tests/boundary/test_redshift_connector.py

from typing import Iterator

import pytest


@pytest.fixture
def schema(connection, schema_name) -> Iterator[str]:
    """Create the schema named by ``schema_name``, yield it, then drop it."""
    with connection.cursor() as cursor:
        cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}")
    yield schema_name
    with connection.cursor() as cursor:
        cursor.execute(f"DROP SCHEMA IF EXISTS {schema_name} CASCADE")


def test_columns_in_relation(connection, schema):
    """Pin the row layout returned by ``cursor.get_columns`` for one column."""
    table = "cross_db"
    with connection.cursor() as cursor:
        cursor.execute(f"CREATE TABLE {schema}.{table} as select 3.14 as id")
        columns = cursor.get_columns(
            schema_pattern=schema,
            tablename_pattern=table,
        )

    assert len(columns) == 1
    column = columns[0]

    # Only the fields asserted below matter; the rest are ignored.
    (
        database_name,
        schema_name,
        table_name,
        column_name,
        type_code,
        type_name,
        precision,
        _,
        scale,
        *_,
    ) = column
    assert schema_name == schema
    assert table_name == table
    assert column_name == "id"
    assert type_code == 2
    assert type_name == "numeric"
    assert precision == 3
    assert scale == 2
# tests/functional/test_columns_in_relation.py

from dbt.adapters.base import Column
from dbt.tests.util import run_dbt
import pytest

from dbt.adapters.redshift import RedshiftRelation


class ColumnsInRelation:
    """Shared scenario: run one model, then compare its reported columns.

    Subclasses override ``project_config_update`` and ``expected_columns``.
    """

    @pytest.fixture(scope="class")
    def models(self):
        model_sql = "select 1.23 as my_num, 'a' as my_char"
        return {"my_model.sql": model_sql}

    @pytest.fixture(scope="class", autouse=True)
    def setup(self, project):
        # Build the model once per class before any test in it runs.
        run_dbt(["run"])

    @pytest.fixture(scope="class")
    def expected_columns(self):
        # Placeholder; each concrete test class supplies its own list.
        return []

    def test_columns_in_relation(self, project, expected_columns):
        relation = RedshiftRelation.create(
            database=project.database,
            schema=project.test_schema,
            identifier="my_model",
            type=RedshiftRelation.View,
        )
        adapter = project.adapter
        with adapter.connection_named("_test"):
            observed = adapter.get_columns_in_relation(relation)
        assert observed == expected_columns


class TestColumnsInRelationBehaviorFlagOff(ColumnsInRelation):
    """No flags set: the custom query path is expected."""

    @pytest.fixture(scope="class")
    def project_config_update(self):
        return {"flags": {}}

    @pytest.fixture(scope="class")
    def expected_columns(self):
        # The custom query reports "character varying" where the SDK query
        # reports "varchar".
        numeric_column = Column(
            column="my_num", dtype="numeric", numeric_precision=3, numeric_scale=2
        )
        char_column = Column(column="my_char", dtype="character varying", char_size=1)
        return [numeric_column, char_column]


class TestColumnsInRelationBehaviorFlagOn(ColumnsInRelation):
    """``restrict_direct_pg_catalog_access`` enabled: the SDK path is expected."""

    @pytest.fixture(scope="class")
    def project_config_update(self):
        return {"flags": {"restrict_direct_pg_catalog_access": True}}

    @pytest.fixture(scope="class")
    def expected_columns(self):
        # The SDK query reports "varchar" where the custom query reports
        # "character varying".
        numeric_column = Column(
            column="my_num", dtype="numeric", numeric_precision=3, numeric_scale=2
        )
        char_column = Column(column="my_char", dtype="varchar", char_size=1)
        return [numeric_column, char_column]