diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index ac12adf2a..224adef88 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -49,6 +49,10 @@ class RedshiftConnectionMethod(StrEnum): IAM_IDENTITY_CENTER_BROWSER = "iam_idc_browser" IAM_IDENTITY_CENTER_TOKEN = "iam_idc_token" + @staticmethod + def uses_identity_center(method: str): + return "_idc_" in method + class UserSSLMode(StrEnum): disable = "disable" @@ -192,18 +196,20 @@ def get_connection_method( # # Helper Methods # - def __assert_required_fields(credentials, required_fields, method_name): - missing_fields = [ + def __assert_required_fields(method_name: str, required_fields: Tuple[str, ...]): + missing_fields: List[str] = [ field for field in required_fields if getattr(credentials, field, None) is None ] if missing_fields: - fields_str = "', '".join(missing_fields) + fields_str: str = "', '".join(missing_fields) raise FailedToConnectError( f"'{fields_str}' field(s) are required for '{method_name}' credentials method" ) def __base_kwargs(credentials) -> Dict[str, Any]: - redshift_ssl_config = RedshiftSSLConfig.parse(credentials.sslmode).to_dict() + redshift_ssl_config: Dict[str, Any] = RedshiftSSLConfig.parse( + credentials.sslmode + ).to_dict() return { "host": credentials.host, "port": int(credentials.port) if credentials.port else 5439, @@ -217,7 +223,13 @@ def __base_kwargs(credentials) -> Dict[str, Any]: def __iam_kwargs(credentials) -> Dict[str, Any]: - if "serverless" in credentials.host: + # iam True except for identity center methods + iam: bool = not RedshiftConnectionMethod.uses_identity_center(credentials.method) + + cluster_identifier: Optional[str] + if "serverless" in credentials.host or RedshiftConnectionMethod.uses_identity_center( + credentials.method + ): cluster_identifier = None elif credentials.cluster_id: cluster_identifier = credentials.cluster_id @@ -228,8 +240,8 @@ def __iam_kwargs(credentials) -> Dict[str, Any]: " 'host' must be provided for serverless endpoint" ) - iam_specific_kwargs = { - "iam": True, + iam_specific_kwargs: Dict[str, Any] = { + "iam": iam, "user": "", "password": "", "cluster_identifier": cluster_identifier, @@ -240,9 +252,9 @@ def __iam_kwargs(credentials) -> Dict[str, Any]: def __database_kwargs(credentials) -> Dict[str, Any]: logger.debug("Connecting to Redshift with 'database' credentials method") - __assert_required_fields(credentials, ["user", "password"], "database") + __assert_required_fields("database", ("user", "password")) - db_credentials = { + db_credentials: Dict[str, Any] = { "user": credentials.user, "password": credentials.password, } @@ -252,6 +264,7 @@ def __database_kwargs(credentials) -> Dict[str, Any]: def __iam_user_kwargs(credentials) -> Dict[str, Any]: logger.debug("Connecting to Redshift with 'iam' credentials method") + iam_credentials: Dict[str, Any] if credentials.access_key_id and credentials.secret_access_key: iam_credentials = { "access_key_id": credentials.access_key_id, @@ -264,7 +277,7 @@ def __iam_user_kwargs(credentials) -> Dict[str, Any]: else: iam_credentials = {"profile": credentials.iam_profile} - __assert_required_fields(credentials, ["user"], "iam") + __assert_required_fields("iam", ("user",)) iam_credentials["db_user"] = credentials.user return __iam_kwargs(credentials) | iam_credentials @@ -283,7 +296,7 @@ def __iam_role_kwargs(credentials) -> Dict[str, Any]: def __iam_idc_browser_kwargs(credentials) -> Dict[str, Any]: logger.debug("Connecting to Redshift with 'iam_idc_browser' credentials method") - identity_center_method_name = "BrowserIdcAuthPlugin" + identity_center_method_name: str = "BrowserIdcAuthPlugin" if credentials.credentials_provider != identity_center_method_name: raise FailedToConnectError( @@ -291,13 +304,13 @@ def __iam_idc_browser_kwargs(credentials) -> Dict[str, Any]: ) __assert_required_fields( - credentials, ["credentials_provider", "idc_region", "issuer_url"], "iam_idc_browser" + "iam_idc_browser", ("credentials_provider", "idc_region", "issuer_url") ) - idc_kwargs = { + idc_kwargs: Dict[str, Any] = { "credentials_provider": identity_center_method_name, - "idc_region": credentials.idc_region, "issuer_url": credentials.issuer_url, + "idc_region": credentials.idc_region, "idc_client_display_name": credentials.idc_client_display_name, "idp_response_timeout": credentials.idp_response_timeout, } @@ -306,16 +319,14 @@ def __iam_idc_browser_kwargs(credentials) -> Dict[str, Any]: def __iam_idc_token_kwargs(credentials) -> Dict[str, Any]: logger.debug("Connecting to Redshift with 'iam_idc_token' credentials method") - identity_center_method_name = "IdpTokenAuthPlugin" + identity_center_method_name: str = "IdpTokenAuthPlugin" if credentials.credentials_provider != identity_center_method_name: raise FailedToConnectError( f"'credentials_provider' must be set to '{identity_center_method_name}'" ) - __assert_required_fields( - credentials, ["credentials_provider", "token", "token_type"], "iam_idc_token" - ) + __assert_required_fields("iam_idc_token", ("credentials_provider", "token", "token_type")) try: _ = IdentityCenterTokenType(credentials.token_type) @@ -324,7 +335,7 @@ def __iam_idc_token_kwargs(credentials) -> Dict[str, Any]: f"'token_type' must be set to one of {[token.value for token in iter(IdentityCenterTokenType)]}" ) - idc_token_kwargs = { + idc_token_kwargs: Dict[str, Any] = { "credentials_provider": identity_center_method_name, "token": credentials.token, "token_type": credentials.token_type, @@ -346,11 +357,13 @@ def __iam_idc_token_kwargs(credentials) -> Dict[str, Any]: } try: - kwargs_function = method_to_kwargs_function[credentials.method] + kwargs_function: Callable[[RedshiftCredentials], Dict[str, Any]] = ( + method_to_kwargs_function[credentials.method] + ) except KeyError: raise FailedToConnectError(f"Invalid 'method' in profile: '{credentials.method}'") - kwargs = kwargs_function(credentials) + kwargs: Dict[str, Any] = kwargs_function(credentials) def connect() -> redshift_connector.Connection: c = redshift_connector.connect(**kwargs) diff --git a/tests/unit/test_auth_method.py b/tests/unit/test_auth_method.py index bc2876672..a1f0d3d3f 100644 --- a/tests/unit/test_auth_method.py +++ b/tests/unit/test_auth_method.py @@ -79,7 +79,7 @@ def test_missing_region_failure(self): connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( - iam=True, + iam=False, host="doesnotexist.1233_no_region", database="redshift", cluster_identifier=None, @@ -590,7 +590,7 @@ def test_profile_idc_browser_all_fields(self): connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( - iam=True, + iam=False, host="doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com", database="redshift", cluster_identifier=None, @@ -621,7 +621,7 @@ def test_profile_idc_browser_required_fields_only(self): connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( - iam=True, + iam=False, host="doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com", database="redshift", cluster_identifier=None, @@ -681,7 +681,7 @@ def test_profile_idc_token_all_required_fields(self): connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( - iam=True, + iam=False, host="doesnotexist.1235.us-east-2.redshift-serverless.amazonaws.com", database="redshift", cluster_identifier=None,