diff --git a/.changes/unreleased/Features-20230730-185656.yaml b/.changes/unreleased/Features-20230730-185656.yaml new file mode 100644 index 000000000..bdaa2f6f9 --- /dev/null +++ b/.changes/unreleased/Features-20230730-185656.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Support all Redshift Connection Methods +time: 2023-07-30T18:56:56.9788512+02:00 +custom: + Author: christopherscholz + Issue: "332" diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 500de430f..0e5a2b6bd 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -1,7 +1,7 @@ import re from multiprocessing import Lock from contextlib import contextmanager -from typing import NewType, Tuple, Union, Optional, List +from typing import List, Optional, Tuple, Union from dataclasses import dataclass, field import agate @@ -12,7 +12,7 @@ from dbt.adapters.sql import SQLConnectionManager from dbt.contracts.connection import AdapterResponse, Connection, Credentials from dbt.contracts.util import Replaceable -from dbt.dataclass_schema import FieldEncoder, dbtClassMixin, StrEnum, ValidationError +from dbt.dataclass_schema import dbtClassMixin, StrEnum, ValidationError from dbt.events import AdapterLogger from dbt.exceptions import DbtRuntimeError, CompilationError import dbt.flags @@ -30,22 +30,21 @@ def get_message(self) -> str: return msg -logger = AdapterLogger("Redshift") - - -drop_lock: Lock = dbt.flags.MP_CONTEXT.Lock() # type: ignore - +class UniqueFieldError(CompilationError): + def __init__(self, exc: ValidationError): + self.exc = exc + super().__init__(msg=self.get_message()) -IAMDuration = NewType("IAMDuration", int) + def get_message(self) -> str: + validator_msg = self.validator_error_message(self.exc) + msg = f"Could not parse unique field: {validator_msg}" + return msg -class IAMDurationEncoder(FieldEncoder): - @property - def json_schema(self): - return {"type": "integer", "minimum": 0, "maximum": 65535} +logger 
= AdapterLogger("Redshift") -dbtClassMixin.register_field_encoders({IAMDuration: IAMDurationEncoder()}) +drop_lock: Lock = dbt.flags.MP_CONTEXT.Lock() # type: ignore class RedshiftConnectionMethod(StrEnum): @@ -113,30 +112,121 @@ def parse(cls, user_sslmode: UserSSLMode) -> "RedshiftSSLConfig": return redshift_ssl +NO_DATABASE = "none" + + +@dataclass +class RedshiftDatabase(dbtClassMixin, Replaceable): # type: ignore + database: Optional[str] = None + + @classmethod + def parse(cls, database: str) -> "RedshiftDatabase": + raw_redshift_database = {"database": database if database != NO_DATABASE else None} + redshift_database = cls.from_dict(raw_redshift_database) + + return redshift_database + + +@dataclass +class RedshiftUniqueField(dbtClassMixin, Replaceable): # type: ignore + unique_field: str + + @classmethod + def parse( + cls, host: Optional[str] = None, cluster_identifier: Optional[str] = None + ) -> "RedshiftUniqueField": + try: + raw_redshift_unique_field = {"unique_field": host if host else cluster_identifier} + cls.validate(raw_redshift_unique_field) + except ValidationError as exc: + raise UniqueFieldError(exc) + + redshift_unique_field = cls.from_dict(raw_redshift_unique_field) + + return redshift_unique_field + + @dataclass class RedshiftCredentials(Credentials): - host: str - user: str - port: Port - method: str = RedshiftConnectionMethod.DATABASE # type: ignore - password: Optional[str] = None # type: ignore - cluster_id: Optional[str] = field( - default=None, - metadata={"description": "If using IAM auth, the name of the cluster"}, - ) - iam_profile: Optional[str] = None - autocreate: bool = False - db_groups: List[str] = field(default_factory=list) - ra3_node: Optional[bool] = False - connect_timeout: Optional[int] = None - role: Optional[str] = None - sslmode: Optional[UserSSLMode] = field(default_factory=UserSSLMode.default) + # functional dbt fields + # schema -> already provided by dbt.contracts.connection.Credentials + + # connection flow 
fields retries: int = 1 - region: Optional[str] = None + method: Optional[str] = None # for backwards compatibility + ra3_node: Optional[bool] = False + + # session specific fields # opt-in by default per team deliberation on https://peps.python.org/pep-0249/#autocommit autocommit: Optional[bool] = True + role: Optional[str] = None - _ALIASES = {"dbname": "database", "pass": "password"} + # connection specific fields based on: + # https://github.com/aws/amazon-redshift-python-driver/blob/v2.0.913/redshift_connector/__init__.py + user: Optional[str] = None + # database already provided by dbt.contracts.connection.Credentials as it is a functional dbt requirement + password: Optional[str] = None + port: Optional[Port] = None + host: Optional[str] = None + source_address: Optional[str] = None + unix_sock: Optional[str] = None + sslmode: Optional[UserSSLMode] = field(default_factory=UserSSLMode.default) + timeout: Optional[int] = None + max_prepared_statements: Optional[int] = None + tcp_keepalive: Optional[bool] = None + application_name: Optional[str] = None + replication: Optional[str] = None + idp_host: Optional[str] = None + db_user: Optional[str] = None + app_id: Optional[str] = None + app_name: Optional[str] = None + preferred_role: Optional[str] = None + principal_arn: Optional[str] = None + access_key_id: Optional[str] = None + secret_access_key: Optional[str] = None + session_token: Optional[str] = None + profile: Optional[str] = None + credentials_provider: Optional[str] = None + region: Optional[str] = None + cluster_identifier: Optional[str] = None + iam: Optional[bool] = None + client_id: Optional[str] = None + idp_tenant: Optional[str] = None + client_secret: Optional[str] = None + partner_sp_id: Optional[str] = None + idp_response_timeout: Optional[int] = None + listen_port: Optional[int] = None + login_url: Optional[str] = None + auto_create: Optional[bool] = False + db_groups: List[str] = field(default_factory=list) + force_lowercase: Optional[bool] =
None + allow_db_user_override: Optional[bool] = None + client_protocol_version: Optional[int] = None + database_metadata_current_db_only: Optional[bool] = None + ssl_insecure: Optional[bool] = None + web_identity_token: Optional[str] = None + role_session_name: Optional[str] = None + role_arn: Optional[str] = None + iam_disable_cache: Optional[bool] = None + auth_profile: Optional[str] = None + endpoint_url: Optional[str] = None + provider_name: Optional[str] = None + scope: Optional[str] = None + numeric_to_float: Optional[bool] = False + is_serverless: Optional[bool] = False + serverless_acct_id: Optional[str] = None + serverless_work_group: Optional[str] = None + group_federation: Optional[bool] = None + + _ALIASES = { + "dbname": "database", + "pass": "password", + # for backwards compatibility: legacy profile key -> current field name + "autocreate": "auto_create", + "cluster_id": "cluster_identifier", + "connect_timeout": "timeout", + "iam_profile": "profile", + } @property def type(self): @@ -144,106 +234,71 @@ def type(self): def _connection_keys(self): return ( - "host", + "schema", + "autocommit", + "role", + "method", + "ra3_node", + "retries", + "database", "user", + "password", "port", - "database", - "method", - "cluster_id", - "iam_profile", - "schema", - "sslmode", - "region", + "host", + "source_address", + "unix_sock", "sslmode", + "timeout", + "max_prepared_statements", + "tcp_keepalive", + "application_name", + "replication", + "idp_host", + "db_user", + "app_id", + "app_name", + "preferred_role", + "principal_arn", + "access_key_id", + "secret_access_key", + "session_token", + "profile", + "credentials_provider", "region", - "iam_profile", - "autocreate", + "cluster_identifier", + "iam", + "client_id", + "idp_tenant", + "client_secret", + "partner_sp_id", + "idp_response_timeout", + "listen_port", + "login_url", + "auto_create", "db_groups", - "ra3_node", - "connect_timeout", - "role", - "retries", - "autocommit", + "force_lowercase", + "allow_db_user_override", +
"client_protocol_version", + "database_metadata_current_db_only", + "ssl_insecure", + "web_identity_token", + "role_session_name", + "role_arn", + "iam_disable_cache", + "auth_profile", + "endpoint_url", + "provider_name", + "scope", + "numeric_to_float", + "is_serverless", + "serverless_acct_id", + "serverless_work_group", + "group_federation", ) @property def unique_field(self) -> str: - return self.host - - -class RedshiftConnectMethodFactory: - credentials: RedshiftCredentials - - def __init__(self, credentials): - self.credentials = credentials - - def get_connect_method(self): - method = self.credentials.method - kwargs = { - "host": self.credentials.host, - "database": self.credentials.database, - "port": int(self.credentials.port) if self.credentials.port else int(5439), - "auto_create": self.credentials.autocreate, - "db_groups": self.credentials.db_groups, - "region": self.credentials.region, - "timeout": self.credentials.connect_timeout, - } - - redshift_ssl_config = RedshiftSSLConfig.parse(self.credentials.sslmode) - kwargs.update(redshift_ssl_config.to_dict()) - - # Support missing 'method' for backwards compatibility - if method == RedshiftConnectionMethod.DATABASE or method is None: - # this requirement is really annoying to encode into json schema, - # so validate it here - if self.credentials.password is None: - raise dbt.exceptions.FailedToConnectError( - "'password' field is required for 'database' credentials" - ) - - def connect(): - logger.debug("Connecting to redshift with username/password based auth...") - c = redshift_connector.connect( - user=self.credentials.user, - password=self.credentials.password, - **kwargs, - ) - if self.credentials.autocommit: - c.autocommit = True - if self.credentials.role: - c.cursor().execute("set role {}".format(self.credentials.role)) - return c - - elif method == RedshiftConnectionMethod.IAM: - if not self.credentials.cluster_id and "serverless" not in self.credentials.host: - raise 
dbt.exceptions.FailedToConnectError( - "Failed to use IAM method. 'cluster_id' must be provided for provisioned cluster. " - "'host' must be provided for serverless endpoint." - ) - - def connect(): - logger.debug("Connecting to redshift with IAM based auth...") - c = redshift_connector.connect( - iam=True, - db_user=self.credentials.user, - password="", - user="", - cluster_identifier=self.credentials.cluster_id, - profile=self.credentials.iam_profile, - **kwargs, - ) - if self.credentials.autocommit: - c.autocommit = True - if self.credentials.role: - c.cursor().execute("set role {}".format(self.credentials.role)) - return c - - else: - raise dbt.exceptions.FailedToConnectError( - "Invalid 'method' in profile: '{}'".format(method) - ) - - return connect + return RedshiftUniqueField.parse(self.host, self.cluster_identifier).unique_field class RedshiftConnectionManager(SQLConnectionManager): @@ -326,7 +381,6 @@ def open(cls, connection): return connection credentials = connection.credentials - connect_method_factory = RedshiftConnectMethodFactory(credentials) def exponential_backoff(attempt: int): return attempt * attempt @@ -337,9 +391,96 @@ def exponential_backoff(attempt: int): redshift_connector.DataError, ] + kwargs = { + k: v + for k, v in { + "user": credentials.user, + "password": credentials.password, + "port": int(credentials.port) if credentials.port else None, + "host": credentials.host, + "source_address": credentials.source_address, + "unix_sock": credentials.unix_sock, + "timeout": credentials.timeout, + "max_prepared_statements": credentials.max_prepared_statements, + "tcp_keepalive": credentials.tcp_keepalive, + "application_name": credentials.application_name, + "replication": credentials.replication, + "idp_host": credentials.idp_host, + "db_user": credentials.db_user, + "app_id": credentials.app_id, + "app_name": credentials.app_name, + "preferred_role": credentials.preferred_role, + "principal_arn": credentials.principal_arn, + "access_key_id": 
credentials.access_key_id, + "secret_access_key": credentials.secret_access_key, + "session_token": credentials.session_token, + "profile": credentials.profile, + "credentials_provider": credentials.credentials_provider, + "region": credentials.region, + "cluster_identifier": credentials.cluster_identifier, + "iam": credentials.iam, + "client_id": credentials.client_id, + "idp_tenant": credentials.idp_tenant, + "client_secret": credentials.client_secret, + "partner_sp_id": credentials.partner_sp_id, + "idp_response_timeout": credentials.idp_response_timeout, + "listen_port": credentials.listen_port, + "login_url": credentials.login_url, + "auto_create": credentials.auto_create, + "db_groups": credentials.db_groups, + "force_lowercase": credentials.force_lowercase, + "allow_db_user_override": credentials.allow_db_user_override, + "client_protocol_version": credentials.client_protocol_version, + "database_metadata_current_db_only": credentials.database_metadata_current_db_only, + "ssl_insecure": credentials.ssl_insecure, + "web_identity_token": credentials.web_identity_token, + "role_session_name": credentials.role_session_name, + "role_arn": credentials.role_arn, + "iam_disable_cache": credentials.iam_disable_cache, + "auth_profile": credentials.auth_profile, + "endpoint_url": credentials.endpoint_url, + "provider_name": credentials.provider_name, + "scope": credentials.scope, + "numeric_to_float": credentials.numeric_to_float, + "is_serverless": credentials.is_serverless, + "serverless_acct_id": credentials.serverless_acct_id, + "serverless_work_group": credentials.serverless_work_group, + "group_federation": credentials.group_federation, + }.items() + if v is not None + } + + # for redshift_connector database is not required and can be provided by an authentication profile + redshift_database = RedshiftDatabase.parse(credentials.database) + kwargs.update(redshift_database.to_dict()) + + redshift_ssl_config = RedshiftSSLConfig.parse(credentials.sslmode) + 
kwargs.update(redshift_ssl_config.to_dict()) + + # for backwards compatibility + if credentials.method == RedshiftConnectionMethod.IAM: + kwargs.update( + { + "iam": True, + "db_user": credentials.user, + "user": "", + "password": "", + } + ) + + def connect(): + c = redshift_connector.connect( + **kwargs, + ) + if credentials.autocommit: + c.autocommit = True + if credentials.role: + c.cursor().execute("set role {}".format(credentials.role)) + return c + return cls.retry_connection( connection, - connect=connect_method_factory.get_connect_method(), + connect, logger=logger, retry_limit=credentials.retries, retry_timeout=exponential_backoff, diff --git a/setup.py b/setup.py index 7b4856351..9eb85dc6a 100644 --- a/setup.py +++ b/setup.py @@ -84,8 +84,6 @@ def _core_version(plugin_version: str = _plugin_version()) -> str: install_requires=[ f"dbt-core~={_core_version()}", f"dbt-postgres~={_core_version()}", - "boto3~=1.26.157", - # dbt-redshift depends deeply on this package. it does not follow SemVer, therefore there have been breaking changes in previous patch releases # Pin to the patch or minor version, and bump in each new minor version of dbt-redshift. 
"redshift-connector==2.0.913", # installed via dbt-core but referenced directly; don't pin to avoid version conflicts with dbt-core diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index c31366a1e..36ddf51b0 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -11,8 +11,7 @@ Plugin as RedshiftPlugin, ) from dbt.clients import agate_helper -from dbt.exceptions import FailedToConnectError -from dbt.adapters.redshift.connections import RedshiftConnectMethodFactory, RedshiftSSLConfig +from dbt.adapters.redshift.connections import RedshiftSSLConfig from .utils import ( config_from_parts_or_dicts, mock_connection, @@ -75,8 +74,8 @@ def test_implicit_database_conn(self): port=5439, auto_create=False, db_groups=[], - timeout=None, - region=None, + numeric_to_float=False, + is_serverless=False, **DEFAULT_SSL_CONFIG, ) @@ -94,8 +93,8 @@ def test_explicit_region_with_database_conn(self): port=5439, auto_create=False, db_groups=[], - region=None, - timeout=None, + numeric_to_float=False, + is_serverless=False, **DEFAULT_SSL_CONFIG, ) @@ -103,7 +102,7 @@ def test_explicit_region_with_database_conn(self): def test_explicit_iam_conn_without_profile(self): self.config.credentials = self.config.credentials.replace( method="iam", - cluster_id="my_redshift", + cluster_identifier="my_redshift", host="thishostshouldnotexist.test.us-east-1", ) connection = self.adapter.acquire_connection("dummy") @@ -116,18 +115,17 @@ def test_explicit_iam_conn_without_profile(self): password="", user="", cluster_identifier="my_redshift", - region=None, - timeout=None, auto_create=False, db_groups=[], - profile=None, port=5439, + numeric_to_float=False, + is_serverless=False, **DEFAULT_SSL_CONFIG, ) @mock.patch("redshift_connector.connect", Mock()) def test_conn_timeout_30(self): - self.config.credentials = self.config.credentials.replace(connect_timeout=30) + self.config.credentials = 
self.config.credentials.replace(timeout=30) connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( @@ -138,18 +136,18 @@ def test_conn_timeout_30(self): port=5439, auto_create=False, db_groups=[], - region=None, timeout=30, + numeric_to_float=False, + is_serverless=False, **DEFAULT_SSL_CONFIG, ) @mock.patch("redshift_connector.connect", Mock()) - @mock.patch("boto3.Session", Mock()) def test_explicit_iam_conn_with_profile(self): self.config.credentials = self.config.credentials.replace( method="iam", - cluster_id="my_redshift", - iam_profile="test", + cluster_identifier="my_redshift", + profile="test", host="thishostshouldnotexist.test.us-east-1", ) connection = self.adapter.acquire_connection("dummy") @@ -160,24 +158,23 @@ def test_explicit_iam_conn_with_profile(self): host="thishostshouldnotexist.test.us-east-1", database="redshift", cluster_identifier="my_redshift", - region=None, auto_create=False, db_groups=[], db_user="root", password="", user="", profile="test", - timeout=None, port=5439, + numeric_to_float=False, + is_serverless=False, **DEFAULT_SSL_CONFIG, ) @mock.patch("redshift_connector.connect", Mock()) - @mock.patch("boto3.Session", Mock()) def test_explicit_iam_serverless_with_profile(self): self.config.credentials = self.config.credentials.replace( method="iam", - iam_profile="test", + profile="test", host="doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com", ) connection = self.adapter.acquire_connection("dummy") @@ -186,26 +183,24 @@ def test_explicit_iam_serverless_with_profile(self): iam=True, host="doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com", database="redshift", - cluster_identifier=None, - region=None, auto_create=False, db_groups=[], db_user="root", password="", user="", profile="test", - timeout=None, port=5439, + numeric_to_float=False, + is_serverless=False, **DEFAULT_SSL_CONFIG, ) @mock.patch("redshift_connector.connect", Mock()) - 
@mock.patch("boto3.Session", Mock()) def test_explicit_region(self): # Successful test self.config.credentials = self.config.credentials.replace( method="iam", - iam_profile="test", + profile="test", host="doesnotexist.1233.redshift-serverless.amazonaws.com", region="us-east-2", ) @@ -215,7 +210,6 @@ def test_explicit_region(self): iam=True, host="doesnotexist.1233.redshift-serverless.amazonaws.com", database="redshift", - cluster_identifier=None, region="us-east-2", auto_create=False, db_groups=[], @@ -223,71 +217,12 @@ def test_explicit_region(self): password="", user="", profile="test", - timeout=None, port=5439, + numeric_to_float=False, + is_serverless=False, **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) - @mock.patch("boto3.Session", Mock()) - def test_explicit_region_failure(self): - # Failure test with no region - self.config.credentials = self.config.credentials.replace( - method="iam", - iam_profile="test", - host="doesnotexist.1233_no_region", - region=None, - ) - - with self.assertRaises(dbt.exceptions.FailedToConnectError): - connection = self.adapter.acquire_connection("dummy") - connection.handle - redshift_connector.connect.assert_called_once_with( - iam=True, - host="doesnotexist.1233_no_region", - database="redshift", - cluster_identifier=None, - auto_create=False, - db_groups=[], - db_user="root", - password="", - user="", - profile="test", - timeout=None, - port=5439, - **DEFAULT_SSL_CONFIG, - ) - - @mock.patch("redshift_connector.connect", Mock()) - @mock.patch("boto3.Session", Mock()) - def test_explicit_invalid_region(self): - # Invalid region test - self.config.credentials = self.config.credentials.replace( - method="iam", - iam_profile="test", - host="doesnotexist.1233_no_region.us-not-a-region-1", - region=None, - ) - - with self.assertRaises(dbt.exceptions.FailedToConnectError): - connection = self.adapter.acquire_connection("dummy") - connection.handle - redshift_connector.connect.assert_called_once_with( - 
iam=True, - host="doesnotexist.1233_no_region", - database="redshift", - cluster_identifier=None, - auto_create=False, - db_groups=[], - db_user="root", - password="", - user="", - profile="test", - timeout=None, - port=5439, - **DEFAULT_SSL_CONFIG, - ) - @mock.patch("redshift_connector.connect", Mock()) def test_sslmode_disable(self): self.config.credentials.sslmode = "disable" @@ -301,10 +236,10 @@ def test_sslmode_disable(self): port=5439, auto_create=False, db_groups=[], - region=None, - timeout=None, ssl=False, sslmode=None, + numeric_to_float=False, + is_serverless=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -320,10 +255,10 @@ def test_sslmode_allow(self): port=5439, auto_create=False, db_groups=[], - region=None, - timeout=None, ssl=True, sslmode="verify-ca", + numeric_to_float=False, + is_serverless=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -339,10 +274,10 @@ def test_sslmode_verify_full(self): port=5439, auto_create=False, db_groups=[], - region=None, - timeout=None, ssl=True, sslmode="verify-full", + numeric_to_float=False, + is_serverless=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -358,10 +293,10 @@ def test_sslmode_verify_ca(self): port=5439, auto_create=False, db_groups=[], - region=None, - timeout=None, ssl=True, sslmode="verify-ca", + numeric_to_float=False, + is_serverless=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -377,41 +312,12 @@ def test_sslmode_prefer(self): port=5439, auto_create=False, db_groups=[], - region=None, - timeout=None, ssl=True, sslmode="verify-ca", + numeric_to_float=False, + is_serverless=False, ) - @mock.patch("redshift_connector.connect", Mock()) - @mock.patch("boto3.Session", Mock()) - def test_serverless_iam_failure(self): - self.config.credentials = self.config.credentials.replace( - method="iam", - iam_profile="test", - host="doesnotexist.1233.us-east-2.redshift-srvrlss.amazonaws.com", - ) - with 
self.assertRaises(dbt.exceptions.FailedToConnectError) as context: - connection = self.adapter.acquire_connection("dummy") - connection.handle - redshift_connector.connect.assert_called_once_with( - iam=True, - host="doesnotexist.1233.us-east-2.redshift-srvrlss.amazonaws.com", - database="redshift", - cluster_identifier=None, - region=None, - auto_create=False, - db_groups=[], - db_user="root", - password="", - user="", - profile="test", - port=5439, - timeout=None, - **DEFAULT_SSL_CONFIG, - ) - self.assertTrue("'host' must be provided" in context.exception.msg) - def test_iam_conn_optionals(self): profile_cfg = { "outputs": { @@ -433,22 +339,6 @@ def test_iam_conn_optionals(self): config_from_parts_or_dicts(self.config, profile_cfg) - def test_invalid_auth_method(self): - # we have to set method this way, otherwise it won't validate - self.config.credentials.method = "badmethod" - with self.assertRaises(FailedToConnectError) as context: - connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) - connect_method_factory.get_connect_method() - self.assertTrue("badmethod" in context.exception.msg) - - def test_invalid_iam_no_cluster_id(self): - self.config.credentials = self.config.credentials.replace(method="iam") - with self.assertRaises(FailedToConnectError) as context: - connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) - connect_method_factory.get_connect_method() - - self.assertTrue("'cluster_id' must be provided" in context.exception.msg) - def test_cancel_open_connections_empty(self): self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0)