diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 521bcff86..ac12adf2a 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -37,10 +37,17 @@ def get_message(self) -> str: logger = AdapterLogger("Redshift") +class IdentityCenterTokenType(StrEnum): + ACCESS_TOKEN = "ACCESS_TOKEN" + EXT_JWT = "EXT_JWT" + + class RedshiftConnectionMethod(StrEnum): DATABASE = "database" IAM = "iam" IAM_ROLE = "iam_role" + IAM_IDENTITY_CENTER_BROWSER = "iam_idc_browser" + IAM_IDENTITY_CENTER_TOKEN = "iam_idc_token" class UserSSLMode(StrEnum): @@ -128,6 +135,22 @@ class RedshiftCredentials(Credentials): access_key_id: Optional[str] = None secret_access_key: Optional[str] = None + # + # IAM identity center methods + # + + # browser + credentials_provider: Optional[str] = None + idc_region: Optional[str] = None + issuer_url: Optional[str] = None + listen_port: int = 7890 + idc_client_display_name: Optional[str] = "Amazon Redshift driver" + idp_response_timeout: int = 60 + + # token + token: Optional[str] = None + token_type: Optional[str] = None + _ALIASES = {"dbname": "database", "pass": "password"} @property @@ -163,131 +186,181 @@ def unique_field(self) -> str: return self.host -class RedshiftConnectMethodFactory: - credentials: RedshiftCredentials +def get_connection_method( + credentials: RedshiftCredentials, +) -> Callable[[], redshift_connector.Connection]: + # + # Helper Methods + # + def __assert_required_fields(credentials, required_fields, method_name): + missing_fields = [ + field for field in required_fields if getattr(credentials, field, None) is None + ] + if missing_fields: + fields_str = "', '".join(missing_fields) + raise FailedToConnectError( + f"'{fields_str}' field(s) are required for '{method_name}' credentials method" + ) - def __init__(self, credentials) -> None: - self.credentials = credentials + def __base_kwargs(credentials) -> Dict[str, Any]: + redshift_ssl_config = RedshiftSSLConfig.parse(credentials.sslmode).to_dict() + return { + "host": credentials.host, + "port": int(credentials.port) if credentials.port else 5439, + "database": credentials.database, + "region": credentials.region, + "auto_create": credentials.autocreate, + "db_groups": credentials.db_groups, + "timeout": credentials.connect_timeout, + **redshift_ssl_config, + } - def get_connect_method(self) -> Callable[[], redshift_connector.Connection]: + def __iam_kwargs(credentials) -> Dict[str, Any]: - # Support missing 'method' for backwards compatibility - method = self.credentials.method or RedshiftConnectionMethod.DATABASE - if method == RedshiftConnectionMethod.DATABASE: - kwargs = self._database_kwargs - elif method == RedshiftConnectionMethod.IAM: - kwargs = self._iam_user_kwargs - elif method == RedshiftConnectionMethod.IAM_ROLE: - kwargs = self._iam_role_kwargs + if "serverless" in credentials.host: + cluster_identifier = None + elif credentials.cluster_id: + cluster_identifier = credentials.cluster_id else: - raise FailedToConnectError(f"Invalid 'method' in profile: '{method}'") + raise FailedToConnectError( + "Failed to use IAM method:" + " 'cluster_id' must be provided for provisioned cluster" + " 'host' must be provided for serverless endpoint" + ) + + iam_specific_kwargs = { + "iam": True, + "user": "", + "password": "", + "cluster_identifier": cluster_identifier, + } - def connect() -> redshift_connector.Connection: - c = redshift_connector.connect(**kwargs) - if self.credentials.autocommit: - c.autocommit = True - if self.credentials.role: - c.cursor().execute(f"set role {self.credentials.role}") - return c + return __base_kwargs(credentials) | iam_specific_kwargs - return connect + def __database_kwargs(credentials) -> Dict[str, Any]: + logger.debug("Connecting to Redshift with 'database' credentials method") - @property - def _database_kwargs(self) -> Dict[str, Any]: - logger.debug("Connecting to redshift with 'database' credentials method") - kwargs = self._base_kwargs - - if self.credentials.user and self.credentials.password: - kwargs.update( - user=self.credentials.user, - password=self.credentials.password, - ) - else: - raise FailedToConnectError( - "'user' and 'password' fields are required for 'database' credentials method" - ) + __assert_required_fields(credentials, ["user", "password"], "database") - return kwargs + db_credentials = { + "user": credentials.user, + "password": credentials.password, + } - @property - def _iam_user_kwargs(self) -> Dict[str, Any]: - logger.debug("Connecting to redshift with 'iam' credentials method") - kwargs = self._iam_kwargs - - if self.credentials.access_key_id and self.credentials.secret_access_key: - kwargs.update( - access_key_id=self.credentials.access_key_id, - secret_access_key=self.credentials.secret_access_key, - ) - elif self.credentials.access_key_id or self.credentials.secret_access_key: + return __base_kwargs(credentials) | db_credentials + + def __iam_user_kwargs(credentials) -> Dict[str, Any]: + logger.debug("Connecting to Redshift with 'iam' credentials method") + + if credentials.access_key_id and credentials.secret_access_key: + iam_credentials = { + "access_key_id": credentials.access_key_id, + "secret_access_key": credentials.secret_access_key, + } + elif credentials.access_key_id or credentials.secret_access_key: raise FailedToConnectError( "'access_key_id' and 'secret_access_key' are both needed if providing explicit credentials" ) else: - kwargs.update(profile=self.credentials.iam_profile) + iam_credentials = {"profile": credentials.iam_profile} - if user := self.credentials.user: - kwargs.update(db_user=user) - else: - raise FailedToConnectError("'user' field is required for 'iam' credentials method") + __assert_required_fields(credentials, ["user"], "iam") + iam_credentials["db_user"] = credentials.user - return kwargs + return __iam_kwargs(credentials) | iam_credentials - @property - def _iam_role_kwargs(self) -> Dict[str, Optional[Any]]: - logger.debug("Connecting to redshift with 'iam_role' credentials method") - kwargs = self._iam_kwargs + def __iam_role_kwargs(credentials) -> Dict[str, Any]: + logger.debug("Connecting to Redshift with 'iam_role' credentials method") + role_kwargs = { + "db_user": None, + "group_federation": "serverless" not in credentials.host, + } - # It's a role, we're ignoring the user - kwargs.update(db_user=None) + if credentials.iam_profile: + role_kwargs["profile"] = credentials.iam_profile - # Serverless shouldn't get group_federation, Provisoned clusters should - if "serverless" in self.credentials.host: - kwargs.update(group_federation=False) - else: - kwargs.update(group_federation=True) + return __iam_kwargs(credentials) | role_kwargs - if iam_profile := self.credentials.iam_profile: - kwargs.update(profile=iam_profile) + def __iam_idc_browser_kwargs(credentials) -> Dict[str, Any]: + logger.debug("Connecting to Redshift with 'iam_idc_browser' credentials method") + identity_center_method_name = "BrowserIdcAuthPlugin" - return kwargs + if credentials.credentials_provider != identity_center_method_name: + raise FailedToConnectError( + f"'credentials_provider' must be set to '{identity_center_method_name}'" + ) - @property - def _iam_kwargs(self) -> Dict[str, Any]: - kwargs = self._base_kwargs - kwargs.update( - iam=True, - user="", - password="", + __assert_required_fields( + credentials, ["credentials_provider", "idc_region", "issuer_url"], "iam_idc_browser" ) - if "serverless" in self.credentials.host: - kwargs.update(cluster_identifier=None) - elif cluster_id := self.credentials.cluster_id: - kwargs.update(cluster_identifier=cluster_id) - else: + idc_kwargs = { + "credentials_provider": identity_center_method_name, + "idc_region": credentials.idc_region, + "issuer_url": credentials.issuer_url, + "idc_client_display_name": credentials.idc_client_display_name, + "idp_response_timeout": credentials.idp_response_timeout, + } + + return __iam_kwargs(credentials) | idc_kwargs + + def __iam_idc_token_kwargs(credentials) -> Dict[str, Any]: + logger.debug("Connecting to Redshift with 'iam_idc_token' credentials method") + identity_center_method_name = "IdpTokenAuthPlugin" + + if credentials.credentials_provider != identity_center_method_name: raise FailedToConnectError( - "Failed to use IAM method:" - " 'cluster_id' must be provided for provisioned cluster" - " 'host' must be provided for serverless endpoint" + f"'credentials_provider' must be set to '{identity_center_method_name}'" ) - return kwargs + __assert_required_fields( + credentials, ["credentials_provider", "token", "token_type"], "iam_idc_token" + ) - @property - def _base_kwargs(self) -> Dict[str, Any]: - kwargs = { - "host": self.credentials.host, - "port": int(self.credentials.port) if self.credentials.port else int(5439), - "database": self.credentials.database, - "region": self.credentials.region, - "auto_create": self.credentials.autocreate, - "db_groups": self.credentials.db_groups, - "timeout": self.credentials.connect_timeout, + try: + _ = IdentityCenterTokenType(credentials.token_type) + except ValueError: + raise FailedToConnectError( + f"'token_type' must be set to one of {[token.value for token in iter(IdentityCenterTokenType)]}" + ) + + idc_token_kwargs = { + "credentials_provider": identity_center_method_name, + "token": credentials.token, + "token_type": credentials.token_type, } - redshift_ssl_config = RedshiftSSLConfig.parse(self.credentials.sslmode) - kwargs.update(redshift_ssl_config.to_dict()) - return kwargs + + return __iam_kwargs(credentials) | idc_token_kwargs + + # + # Head of function execution + # + + method_to_kwargs_function = { + None: __database_kwargs, + RedshiftConnectionMethod.DATABASE: __database_kwargs, + RedshiftConnectionMethod.IAM: __iam_user_kwargs, + RedshiftConnectionMethod.IAM_ROLE: __iam_role_kwargs, + RedshiftConnectionMethod.IAM_IDENTITY_CENTER_BROWSER: __iam_idc_browser_kwargs, + RedshiftConnectionMethod.IAM_IDENTITY_CENTER_TOKEN: __iam_idc_token_kwargs, + } + + try: + kwargs_function = method_to_kwargs_function[credentials.method] + except KeyError: + raise FailedToConnectError(f"Invalid 'method' in profile: '{credentials.method}'") + + kwargs = kwargs_function(credentials) + + def connect() -> redshift_connector.Connection: + c = redshift_connector.connect(**kwargs) + if credentials.autocommit: + c.autocommit = True + if credentials.role: + c.cursor().execute(f"set role {credentials.role}") + return c + + return connect class RedshiftConnectionManager(SQLConnectionManager): @@ -373,7 +446,6 @@ def open(cls, connection): return connection credentials = connection.credentials - connect_method_factory = RedshiftConnectMethodFactory(credentials) def exponential_backoff(attempt: int): return attempt * attempt @@ -387,7 +459,7 @@ def exponential_backoff(attempt: int): open_connection = cls.retry_connection( connection, - connect=connect_method_factory.get_connect_method(), + connect=get_connection_method(credentials), logger=logger, retry_limit=credentials.retries, retry_timeout=exponential_backoff, diff --git a/tests/unit/test_auth_method.py b/tests/unit/test_auth_method.py index 55b1aad74..bc2876672 100644 --- a/tests/unit/test_auth_method.py +++ b/tests/unit/test_auth_method.py @@ -9,7 +9,7 @@ Plugin as RedshiftPlugin, RedshiftAdapter, ) -from dbt.adapters.redshift.connections import RedshiftConnectMethodFactory, RedshiftSSLConfig +from dbt.adapters.redshift.connections import get_connection_method, RedshiftSSLConfig from tests.unit.utils import config_from_parts_or_dicts, inject_adapter @@ -61,7 +61,7 @@ def test_invalid_auth_method(self): # we have to set method this way, otherwise it won't validate self.config.credentials.method = "badmethod" with self.assertRaises(FailedToConnectError) as context: - connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) + connect_method_factory = get_connection_method(self.config.credentials) connect_method_factory.get_connect_method() self.assertTrue("badmethod" in context.exception.msg) @@ -221,7 +221,7 @@ def test_iam_optionals(self): def test_no_cluster_id(self): self.config.credentials = self.config.credentials.replace(method="iam") with self.assertRaises(FailedToConnectError) as context: - connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) + connect_method_factory = get_connection_method(self.config.credentials) connect_method_factory.get_connect_method() self.assertTrue("'cluster_id' must be provided" in context.exception.msg) @@ -400,7 +400,7 @@ class TestIAMRoleMethod(AuthMethod): def test_no_cluster_id(self): self.config.credentials = self.config.credentials.replace(method="iam_role") with self.assertRaises(FailedToConnectError) as context: - connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) + connect_method_factory = get_connection_method(self.config.credentials) connect_method_factory.get_connect_method() self.assertTrue("'cluster_id' must be provided" in context.exception.msg) @@ -573,3 +573,157 @@ def test_profile_invalid_serverless(self): **DEFAULT_SSL_CONFIG, ) self.assertTrue("'host' must be provided" in context.exception.msg) + + +class TestIAMIdcBrowser(AuthMethod): + @mock.patch("redshift_connector.connect", MagicMock()) + def test_profile_idc_browser_all_fields(self): + self.config.credentials = self.config.credentials.replace( + method="iam_idc_browser", + credentials_provider="BrowserIdcAuthPlugin", + idc_region="us-east-1", + issuer_url="https://identitycenter.amazonaws.com/ssoins-randomchars", + idc_client_display_name="display name", + idp_response_timeout=0, + host="doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com", + ) + connection = self.adapter.acquire_connection("dummy") + connection.handle + redshift_connector.connect.assert_called_once_with( + iam=True, + host="doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com", + database="redshift", + cluster_identifier=None, + region=None, + auto_create=False, + db_groups=[], + password="", + user="", + timeout=None, + port=5439, + **DEFAULT_SSL_CONFIG, + idp_response_timeout=0, + idc_client_display_name="display name", + credentials_provider="BrowserIdcAuthPlugin", + idc_region="us-east-1", + issuer_url="https://identitycenter.amazonaws.com/ssoins-randomchars", + ) + + @mock.patch("redshift_connector.connect", MagicMock()) + def test_profile_idc_browser_required_fields_only(self): + self.config.credentials = self.config.credentials.replace( + method="iam_idc_browser", + credentials_provider="BrowserIdcAuthPlugin", + idc_region="us-east-1", + issuer_url="https://identitycenter.amazonaws.com/ssoins-randomchars", + host="doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com", + ) + connection = self.adapter.acquire_connection("dummy") + connection.handle + redshift_connector.connect.assert_called_once_with( + iam=True, + host="doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com", + database="redshift", + cluster_identifier=None, + region=None, + auto_create=False, + db_groups=[], + password="", + user="", + timeout=None, + port=5439, + **DEFAULT_SSL_CONFIG, + idp_response_timeout=60, + idc_client_display_name="Amazon Redshift driver", + credentials_provider="BrowserIdcAuthPlugin", + idc_region="us-east-1", + issuer_url="https://identitycenter.amazonaws.com/ssoins-randomchars", + ) + + def test_invalid_plugin_for_idc_browser_auth_method(self): + self.config.credentials = self.config.credentials.replace( + method="iam_idc_browser", + credentials_provider="IdpTokenAuthPlugin", + ) + with self.assertRaises(FailedToConnectError) as context: + connection = self.adapter.acquire_connection("dummy") + connection.handle + + assert "BrowserIdcAuthPlugin" in context.exception.msg + + def test_invalid_adapter_missing_fields(self): + self.config.credentials = self.config.credentials.replace( + method="iam_idc_browser", + credentials_provider="BrowserIdcAuthPlugin", + idc_client_display_name="my display", + ) + with self.assertRaises(FailedToConnectError) as context: + connection = self.adapter.acquire_connection("dummy") + connection.handle + + assert ( + "'idc_region', 'issuer_url' field(s) are required for 'iam_idc_browser' credentials method" + in context.exception.msg + ) + + +class TestIAMIdcToken(AuthMethod): + @mock.patch("redshift_connector.connect", MagicMock()) + def test_profile_idc_token_all_required_fields(self): + """Same as all possible fields""" + self.config.credentials = self.config.credentials.replace( + method="iam_idc_token", + credentials_provider="IdpTokenAuthPlugin", + token="token", + token_type="ACCESS_TOKEN", + host="doesnotexist.1235.us-east-2.redshift-serverless.amazonaws.com", + ) + connection = self.adapter.acquire_connection("dummy") + connection.handle + redshift_connector.connect.assert_called_once_with( + iam=True, + host="doesnotexist.1235.us-east-2.redshift-serverless.amazonaws.com", + database="redshift", + cluster_identifier=None, + region=None, + auto_create=False, + db_groups=[], + password="", + user="", + timeout=None, + port=5439, + **DEFAULT_SSL_CONFIG, + credentials_provider="IdpTokenAuthPlugin", + token="token", + token_type="ACCESS_TOKEN", + ) + + def test_invalid_plugin_for_idc_token_auth_method(self): + self.config.credentials = self.config.credentials.replace( + method="iam_idc_token", + token="token", + token_type="ACCESS_TOKEN", + credentials_provider="BrowserIdcAuthPlugin", + ) + with self.assertRaises(FailedToConnectError) as context: + connection = self.adapter.acquire_connection("dummy") + connection.handle + + assert "IdpTokenAuthPlugin" in context.exception.msg + + @mock.patch("redshift_connector.connect", MagicMock()) + def test_invalid_idc_token_missing_field(self): + # Successful test + self.config.credentials = self.config.credentials.replace( + method="iam_idc_token", + credentials_provider="IdpTokenAuthPlugin", + token_type="ACCESS_TOKEN", + host="doesnotexist.1235.us-east-2.redshift-serverless.amazonaws.com", + ) + with self.assertRaises(FailedToConnectError) as context: + connection = self.adapter.acquire_connection("dummy") + connection.handle + assert ( + "'token' field(s) are required for 'iam_idc_token' credentials method" + in context.exception.msg + )