diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 0c9d1b7ed..375e2244a 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -39,6 +39,7 @@ def get_message(self) -> str: class RedshiftConnectionMethod(StrEnum): DATABASE = "database" IAM = "iam" + IAMR = "iamr" class UserSSLMode(StrEnum): @@ -104,9 +105,9 @@ def parse(cls, user_sslmode: UserSSLMode) -> "RedshiftSSLConfig": @dataclass class RedshiftCredentials(Credentials): host: str - user: str port: Port method: str = RedshiftConnectionMethod.DATABASE # type: ignore + user: Optional[str] = None # type: ignore password: Optional[str] = None # type: ignore cluster_id: Optional[str] = field( default=None, @@ -226,6 +227,27 @@ def connect(): c.cursor().execute("set role {}".format(self.credentials.role)) return c + elif method == RedshiftConnectionMethod.IAMR: + if not self.credentials.cluster_id and "serverless" not in self.credentials.host: + raise dbt.exceptions.FailedToConnectError( + "Failed to use IAM method. 'cluster_id' must be provided for provisioned cluster. " + "'host' must be provided for serverless endpoint." + ) + + def connect(): + logger.debug("Connecting to redshift with IAM based auth...") + c = redshift_connector.connect( + iam=True, + cluster_identifier=self.credentials.cluster_id, + profile=self.credentials.iam_profile, + **kwargs, + ) + if self.credentials.autocommit: + c.autocommit = True + if self.credentials.role: + c.cursor().execute("set role {}".format(self.credentials.role)) + return c + else: raise dbt.exceptions.FailedToConnectError( "Invalid 'method' in profile: '{}'".format(method) diff --git a/dbt/include/redshift/profile_template.yml b/dbt/include/redshift/profile_template.yml index 41f33e87e..5f6b0a91a 100644 --- a/dbt/include/redshift/profile_template.yml +++ b/dbt/include/redshift/profile_template.yml @@ -15,6 +15,8 @@ prompts: hide_input: true iam: _fixed_method: iam + iamr: + _fixed_method: iamr dbname: hint: 'default database that dbt will build objects in' schema: diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index 1edea565e..f1752dbd3 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -12,7 +12,10 @@ ) from dbt.clients import agate_helper from dbt.exceptions import FailedToConnectError -from dbt.adapters.redshift.connections import RedshiftConnectMethodFactory, RedshiftSSLConfig +from dbt.adapters.redshift.connections import ( + RedshiftConnectMethodFactory, + RedshiftSSLConfig, +) from .utils import ( config_from_parts_or_dicts, mock_connection, @@ -199,6 +202,55 @@ def test_explicit_iam_serverless_with_profile(self): **DEFAULT_SSL_CONFIG, ) + @mock.patch("redshift_connector.connect", Mock()) + def test_explicit_iamr_conn_without_profile(self): + self.config.credentials = self.config.credentials.replace( + method="iamr", + cluster_id="my_redshift", + host="thishostshouldnotexist.test.us-east-1", + ) + connection = self.adapter.acquire_connection("dummy") + connection.handle + redshift_connector.connect.assert_called_once_with( + iam=True, + host="thishostshouldnotexist.test.us-east-1", + database="redshift", + cluster_identifier="my_redshift", + region=None, + timeout=None, + auto_create=False, + db_groups=[], + profile=None, + port=5439, + **DEFAULT_SSL_CONFIG, + ) + + @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("boto3.Session", Mock()) + def test_explicit_iamr_conn_with_profile(self): + self.config.credentials = self.config.credentials.replace( + method="iamr", + cluster_id="my_redshift", + iam_profile="test", + host="thishostshouldnotexist.test.us-east-1", + ) + connection = self.adapter.acquire_connection("dummy") + connection.handle + + redshift_connector.connect.assert_called_once_with( + iam=True, + host="thishostshouldnotexist.test.us-east-1", + database="redshift", + cluster_identifier="my_redshift", + region=None, + auto_create=False, + db_groups=[], + profile="test", + timeout=None, + port=5439, + **DEFAULT_SSL_CONFIG, + ) + @mock.patch("redshift_connector.connect", Mock()) @mock.patch("boto3.Session", Mock()) def test_explicit_region(self):