Skip to content

Commit

Permalink
Add authentication methods and unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
VersusFacit committed Nov 14, 2024
1 parent cf5acf1 commit 8aa3a02
Show file tree
Hide file tree
Showing 2 changed files with 329 additions and 103 deletions.
270 changes: 171 additions & 99 deletions dbt/adapters/redshift/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,17 @@ def get_message(self) -> str:
logger = AdapterLogger("Redshift")


class IdentityCenterTokenType(StrEnum):
ACCESS_TOKEN = "ACCESS_TOKEN"
EXT_JWT = "EXT_JWT"


class RedshiftConnectionMethod(StrEnum):
DATABASE = "database"
IAM = "iam"
IAM_ROLE = "iam_role"
IAM_IDENTITY_CENTER_BROWSER = "iam_idc_browser"
IAM_IDENTITY_CENTER_TOKEN = "iam_idc_token"


class UserSSLMode(StrEnum):
Expand Down Expand Up @@ -128,6 +135,22 @@ class RedshiftCredentials(Credentials):
access_key_id: Optional[str] = None
secret_access_key: Optional[str] = None

#
# IAM identity center methods
#

# browser
credentials_provider: Optional[str] = None
idc_region: Optional[str] = None
issuer_url: Optional[str] = None
listen_port: int = 7890
idc_client_display_name: Optional[str] = "Amazon Redshift driver"
idp_response_timeout: int = 60

# token
token: Optional[str] = None
token_type: Optional[str] = None

_ALIASES = {"dbname": "database", "pass": "password"}

@property
Expand Down Expand Up @@ -163,131 +186,181 @@ def unique_field(self) -> str:
return self.host


class RedshiftConnectMethodFactory:
credentials: RedshiftCredentials
def get_connection_method(
credentials: RedshiftCredentials,
) -> Callable[[], redshift_connector.Connection]:
#
# Helper Methods
#
def __assert_required_fields(credentials, required_fields, method_name):
missing_fields = [
field for field in required_fields if getattr(credentials, field, None) is None
]
if missing_fields:
fields_str = "', '".join(missing_fields)
raise FailedToConnectError(
f"'{fields_str}' field(s) are required for '{method_name}' credentials method"
)

def __init__(self, credentials) -> None:
self.credentials = credentials
def __base_kwargs(credentials) -> Dict[str, Any]:
redshift_ssl_config = RedshiftSSLConfig.parse(credentials.sslmode).to_dict()
return {
"host": credentials.host,
"port": int(credentials.port) if credentials.port else 5439,
"database": credentials.database,
"region": credentials.region,
"auto_create": credentials.autocreate,
"db_groups": credentials.db_groups,
"timeout": credentials.connect_timeout,
**redshift_ssl_config,
}

def get_connect_method(self) -> Callable[[], redshift_connector.Connection]:
def __iam_kwargs(credentials) -> Dict[str, Any]:

# Support missing 'method' for backwards compatibility
method = self.credentials.method or RedshiftConnectionMethod.DATABASE
if method == RedshiftConnectionMethod.DATABASE:
kwargs = self._database_kwargs
elif method == RedshiftConnectionMethod.IAM:
kwargs = self._iam_user_kwargs
elif method == RedshiftConnectionMethod.IAM_ROLE:
kwargs = self._iam_role_kwargs
if "serverless" in credentials.host:
cluster_identifier = None
elif credentials.cluster_id:
cluster_identifier = credentials.cluster_id
else:
raise FailedToConnectError(f"Invalid 'method' in profile: '{method}'")
raise FailedToConnectError(
"Failed to use IAM method:"
" 'cluster_id' must be provided for provisioned cluster"
" 'host' must be provided for serverless endpoint"
)

iam_specific_kwargs = {
"iam": True,
"user": "",
"password": "",
"cluster_identifier": cluster_identifier,
}

def connect() -> redshift_connector.Connection:
c = redshift_connector.connect(**kwargs)
if self.credentials.autocommit:
c.autocommit = True
if self.credentials.role:
c.cursor().execute(f"set role {self.credentials.role}")
return c
return __base_kwargs(credentials) | iam_specific_kwargs

return connect
def __database_kwargs(credentials) -> Dict[str, Any]:
logger.debug("Connecting to Redshift with 'database' credentials method")

@property
def _database_kwargs(self) -> Dict[str, Any]:
logger.debug("Connecting to redshift with 'database' credentials method")
kwargs = self._base_kwargs

if self.credentials.user and self.credentials.password:
kwargs.update(
user=self.credentials.user,
password=self.credentials.password,
)
else:
raise FailedToConnectError(
"'user' and 'password' fields are required for 'database' credentials method"
)
__assert_required_fields(credentials, ["user", "password"], "database")

return kwargs
db_credentials = {
"user": credentials.user,
"password": credentials.password,
}

@property
def _iam_user_kwargs(self) -> Dict[str, Any]:
logger.debug("Connecting to redshift with 'iam' credentials method")
kwargs = self._iam_kwargs

if self.credentials.access_key_id and self.credentials.secret_access_key:
kwargs.update(
access_key_id=self.credentials.access_key_id,
secret_access_key=self.credentials.secret_access_key,
)
elif self.credentials.access_key_id or self.credentials.secret_access_key:
return __base_kwargs(credentials) | db_credentials

def __iam_user_kwargs(credentials) -> Dict[str, Any]:
logger.debug("Connecting to Redshift with 'iam' credentials method")

if credentials.access_key_id and credentials.secret_access_key:
iam_credentials = {
"access_key_id": credentials.access_key_id,
"secret_access_key": credentials.secret_access_key,
}
elif credentials.access_key_id or credentials.secret_access_key:
raise FailedToConnectError(
"'access_key_id' and 'secret_access_key' are both needed if providing explicit credentials"
)
else:
kwargs.update(profile=self.credentials.iam_profile)
iam_credentials = {"profile": credentials.iam_profile}

if user := self.credentials.user:
kwargs.update(db_user=user)
else:
raise FailedToConnectError("'user' field is required for 'iam' credentials method")
__assert_required_fields(credentials, ["user"], "iam")
iam_credentials["db_user"] = credentials.user

return kwargs
return __iam_kwargs(credentials) | iam_credentials

@property
def _iam_role_kwargs(self) -> Dict[str, Optional[Any]]:
logger.debug("Connecting to redshift with 'iam_role' credentials method")
kwargs = self._iam_kwargs
def __iam_role_kwargs(credentials) -> Dict[str, Any]:
logger.debug("Connecting to Redshift with 'iam_role' credentials method")
role_kwargs = {
"db_user": None,
"group_federation": "serverless" not in credentials.host,
}

# It's a role, we're ignoring the user
kwargs.update(db_user=None)
if credentials.iam_profile:
role_kwargs["profile"] = credentials.iam_profile

# Serverless shouldn't get group_federation, Provisoned clusters should
if "serverless" in self.credentials.host:
kwargs.update(group_federation=False)
else:
kwargs.update(group_federation=True)
return __iam_kwargs(credentials) | role_kwargs

if iam_profile := self.credentials.iam_profile:
kwargs.update(profile=iam_profile)
def __iam_idc_browser_kwargs(credentials) -> Dict[str, Any]:
logger.debug("Connecting to Redshift with 'iam_idc_browser' credentials method")
identity_center_method_name = "BrowserIdcAuthPlugin"

return kwargs
if credentials.credentials_provider != identity_center_method_name:
raise FailedToConnectError(
f"'credentials_provider' must be set to '{identity_center_method_name}'"
)

@property
def _iam_kwargs(self) -> Dict[str, Any]:
kwargs = self._base_kwargs
kwargs.update(
iam=True,
user="",
password="",
__assert_required_fields(
credentials, ["credentials_provider", "idc_region", "issuer_url"], "iam_idc_browser"
)

if "serverless" in self.credentials.host:
kwargs.update(cluster_identifier=None)
elif cluster_id := self.credentials.cluster_id:
kwargs.update(cluster_identifier=cluster_id)
else:
idc_kwargs = {
"credentials_provider": identity_center_method_name,
"idc_region": credentials.idc_region,
"issuer_url": credentials.issuer_url,
"idc_client_display_name": credentials.idc_client_display_name,
"idp_response_timeout": credentials.idp_response_timeout,
}

return __iam_kwargs(credentials) | idc_kwargs

def __iam_idc_token_kwargs(credentials) -> Dict[str, Any]:
logger.debug("Connecting to Redshift with 'iam_idc_token' credentials method")
identity_center_method_name = "IdpTokenAuthPlugin"

if credentials.credentials_provider != identity_center_method_name:
raise FailedToConnectError(
"Failed to use IAM method:"
" 'cluster_id' must be provided for provisioned cluster"
" 'host' must be provided for serverless endpoint"
f"'credentials_provider' must be set to '{identity_center_method_name}'"
)

return kwargs
__assert_required_fields(
credentials, ["credentials_provider", "token", "token_type"], "iam_idc_token"
)

@property
def _base_kwargs(self) -> Dict[str, Any]:
kwargs = {
"host": self.credentials.host,
"port": int(self.credentials.port) if self.credentials.port else int(5439),
"database": self.credentials.database,
"region": self.credentials.region,
"auto_create": self.credentials.autocreate,
"db_groups": self.credentials.db_groups,
"timeout": self.credentials.connect_timeout,
try:
_ = IdentityCenterTokenType(credentials.token_type)
except ValueError:
raise FailedToConnectError(
f"'token_type' must be set to one of {[token.value for token in iter(IdentityCenterTokenType)]}"
)

idc_token_kwargs = {
"credentials_provider": identity_center_method_name,
"token": credentials.token,
"token_type": credentials.token_type,
}
redshift_ssl_config = RedshiftSSLConfig.parse(self.credentials.sslmode)
kwargs.update(redshift_ssl_config.to_dict())
return kwargs

return __iam_kwargs(credentials) | idc_token_kwargs

#
# Head of function execution
#

method_to_kwargs_function = {
None: __database_kwargs,
RedshiftConnectionMethod.DATABASE: __database_kwargs,
RedshiftConnectionMethod.IAM: __iam_user_kwargs,
RedshiftConnectionMethod.IAM_ROLE: __iam_role_kwargs,
RedshiftConnectionMethod.IAM_IDENTITY_CENTER_BROWSER: __iam_idc_browser_kwargs,
RedshiftConnectionMethod.IAM_IDENTITY_CENTER_TOKEN: __iam_idc_token_kwargs,
}

try:
kwargs_function = method_to_kwargs_function[credentials.method]
except KeyError:
raise FailedToConnectError(f"Invalid 'method' in profile: '{credentials.method}'")

kwargs = kwargs_function(credentials)

def connect() -> redshift_connector.Connection:
c = redshift_connector.connect(**kwargs)
if credentials.autocommit:
c.autocommit = True
if credentials.role:
c.cursor().execute(f"set role {credentials.role}")
return c

return connect


class RedshiftConnectionManager(SQLConnectionManager):
Expand Down Expand Up @@ -373,7 +446,6 @@ def open(cls, connection):
return connection

credentials = connection.credentials
connect_method_factory = RedshiftConnectMethodFactory(credentials)

def exponential_backoff(attempt: int):
return attempt * attempt
Expand All @@ -387,7 +459,7 @@ def exponential_backoff(attempt: int):

open_connection = cls.retry_connection(
connection,
connect=connect_method_factory.get_connect_method(),
connect=get_connection_method(credentials),
logger=logger,
retry_limit=credentials.retries,
retry_timeout=exponential_backoff,
Expand Down
Loading

0 comments on commit 8aa3a02

Please sign in to comment.