diff --git a/.changes/unreleased/Features-20241122-143326.yaml b/.changes/unreleased/Features-20241122-143326.yaml new file mode 100644 index 000000000..a4b8a7089 --- /dev/null +++ b/.changes/unreleased/Features-20241122-143326.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Add browser identity center authentication method. +time: 2024-11-22T14:33:26.549878-08:00 +custom: + Author: versusfacit + Issue: "898" diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 521bcff86..d93847634 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -37,10 +37,24 @@ def get_message(self) -> str: logger = AdapterLogger("Redshift") +class IdentityCenterTokenType(StrEnum): + ACCESS_TOKEN = "ACCESS_TOKEN" + EXT_JWT = "EXT_JWT" + + class RedshiftConnectionMethod(StrEnum): DATABASE = "database" IAM = "iam" IAM_ROLE = "iam_role" + IAM_IDENTITY_CENTER_BROWSER = "browser_identity_center" + + @classmethod + def uses_identity_center(cls, method: str) -> bool: + return method in (cls.IAM_IDENTITY_CENTER_BROWSER,) + + @classmethod + def is_iam(cls, method: str) -> bool: + return not cls.uses_identity_center(method) class UserSSLMode(StrEnum): @@ -128,6 +142,17 @@ class RedshiftCredentials(Credentials): access_key_id: Optional[str] = None secret_access_key: Optional[str] = None + # + # IAM identity center methods + # + + # browser + idc_region: Optional[str] = None + issuer_url: Optional[str] = None + idp_listen_port: Optional[int] = 7890 + idc_client_display_name: Optional[str] = "Amazon Redshift driver" + idp_response_timeout: Optional[int] = None + _ALIASES = {"dbname": "database", "pass": "password"} @property @@ -163,131 +188,171 @@ def unique_field(self) -> str: return self.host -class RedshiftConnectMethodFactory: - credentials: RedshiftCredentials - - def __init__(self, credentials) -> None: - self.credentials = credentials - - def get_connect_method(self) -> Callable[[], redshift_connector.Connection]: 
def get_connection_method(
    credentials: RedshiftCredentials,
) -> Callable[[], redshift_connector.Connection]:
    """Build a zero-argument callable that opens a Redshift connection.

    The keyword arguments for ``redshift_connector.connect`` are computed
    while this function runs, so profile errors are raised here rather than
    when the returned callable is invoked.

    Args:
        credentials: Profile credentials. ``credentials.method`` selects the
            authentication flow; ``None`` is handled like the ``database``
            method so profiles without a ``method`` keep working.

    Returns:
        A callable that calls ``redshift_connector.connect`` with the prepared
        keyword arguments, turns on autocommit when ``credentials.autocommit``
        is set, and executes ``set role`` when ``credentials.role`` is set.

    Raises:
        FailedToConnectError: If ``credentials.method`` is not a supported
            method, or a field required by the selected method is missing.
    """

    #
    # Helper Methods
    #
    def _validate_required_fields(method_name: str, required_fields: Tuple[str, ...]) -> None:
        """Raise if any of *required_fields* is absent or ``None`` on the credentials."""
        missing_fields: List[str] = [
            field for field in required_fields if getattr(credentials, field, None) is None
        ]
        if missing_fields:
            fields_str: str = "', '".join(missing_fields)
            raise FailedToConnectError(
                f"'{fields_str}' field(s) are required for '{method_name}' credentials method"
            )

    def _base_kwargs(credentials) -> Dict[str, Any]:
        """Return the connection arguments shared by every method."""
        redshift_ssl_config: Dict[str, Any] = RedshiftSSLConfig.parse(
            credentials.sslmode
        ).to_dict()
        return {
            "host": credentials.host,
            # 5439 is used when the profile does not set a port
            "port": int(credentials.port) if credentials.port else 5439,
            "database": credentials.database,
            "region": credentials.region,
            "auto_create": credentials.autocreate,
            "db_groups": credentials.db_groups,
            "timeout": credentials.connect_timeout,
            **redshift_ssl_config,
        }

    def _iam_kwargs(credentials) -> Dict[str, Any]:
        """Return base arguments plus those common to IAM and identity center methods."""
        # iam True except for identity center methods
        iam: bool = RedshiftConnectionMethod.is_iam(credentials.method)

        # Serverless hosts and identity center methods connect without a
        # cluster identifier; otherwise 'cluster_id' is mandatory.
        no_cluster_needed: bool = (
            "serverless" in credentials.host
            or RedshiftConnectionMethod.uses_identity_center(credentials.method)
        )
        cluster_identifier: Optional[str]
        if no_cluster_needed:
            cluster_identifier = None
        elif credentials.cluster_id:
            cluster_identifier = credentials.cluster_id
        else:
            raise FailedToConnectError(
                "Failed to use IAM method:"
                " 'cluster_id' must be provided for provisioned cluster"
                " 'host' must be provided for serverless endpoint"
            )

        iam_specific_kwargs: Dict[str, Any] = {
            "iam": iam,
            "user": "",
            "password": "",
            "cluster_identifier": cluster_identifier,
        }

        return _base_kwargs(credentials) | iam_specific_kwargs

    def _database_kwargs(credentials) -> Dict[str, Any]:
        """Return arguments for the 'database' (user/password) method."""
        logger.debug("Connecting to Redshift with 'database' credentials method")

        _validate_required_fields("database", ("user", "password"))

        db_credentials: Dict[str, Any] = {
            "user": credentials.user,
            "password": credentials.password,
        }

        return _base_kwargs(credentials) | db_credentials

    def _iam_user_kwargs(credentials) -> Dict[str, Any]:
        """Return arguments for the 'iam' method.

        Uses the explicit access key pair when both halves are given,
        otherwise passes ``iam_profile`` as the driver's ``profile``.
        """
        logger.debug("Connecting to Redshift with 'iam' credentials method")

        iam_credentials: Dict[str, Any]
        if credentials.access_key_id and credentials.secret_access_key:
            iam_credentials = {
                "access_key_id": credentials.access_key_id,
                "secret_access_key": credentials.secret_access_key,
            }
        elif credentials.access_key_id or credentials.secret_access_key:
            # Only one half of the key pair was supplied.
            raise FailedToConnectError(
                "'access_key_id' and 'secret_access_key' are both needed if providing explicit credentials"
            )
        else:
            iam_credentials = {"profile": credentials.iam_profile}

        _validate_required_fields("iam", ("user",))
        iam_credentials["db_user"] = credentials.user

        return _iam_kwargs(credentials) | iam_credentials

    def _iam_role_kwargs(credentials) -> Dict[str, Any]:
        """Return arguments for the 'iam_role' method."""
        logger.debug("Connecting to Redshift with 'iam_role' credentials method")
        role_kwargs: Dict[str, Any] = {
            # It's a role, so the profile's user is ignored
            "db_user": None,
            # Serverless shouldn't get group_federation, provisioned clusters should
            "group_federation": "serverless" not in credentials.host,
        }

        if credentials.iam_profile:
            role_kwargs["profile"] = credentials.iam_profile

        return _iam_kwargs(credentials) | role_kwargs

    def _iam_idc_browser_kwargs(credentials) -> Dict[str, Any]:
        """Return arguments for the 'browser_identity_center' method."""
        # The f-prefix is required here; without it the placeholder is logged literally.
        logger.debug(f"Connecting to Redshift with '{credentials.method}' credentials method")

        default_idp_timeout: int = 60
        default_listen_port: int = 7890

        _validate_required_fields(
            "browser_identity_center", ("method", "idc_region", "issuer_url")
        )

        # An explicit timeout of 0 is kept; only an unset (None) value falls
        # back to the default.
        idp_timeout: int = (
            credentials.idp_response_timeout
            if credentials.idp_response_timeout is not None
            else default_idp_timeout
        )

        # Any falsy port (None or 0) falls back to the default port.
        idp_listen_port: int = credentials.idp_listen_port or default_listen_port

        idc_kwargs: Dict[str, Any] = {
            "credentials_provider": "BrowserIdcAuthPlugin",
            "issuer_url": credentials.issuer_url,
            "listen_port": idp_listen_port,
            "idc_region": credentials.idc_region,
            "idc_client_display_name": credentials.idc_client_display_name,
            "idp_response_timeout": idp_timeout,
        }

        return _iam_kwargs(credentials) | idc_kwargs

    #
    # Head of function execution
    #

    method_to_kwargs_function: Dict[
        Optional[str], Callable[[RedshiftCredentials], Dict[str, Any]]
    ] = {
        # Support missing 'method' for backwards compatibility
        None: _database_kwargs,
        RedshiftConnectionMethod.DATABASE: _database_kwargs,
        RedshiftConnectionMethod.IAM: _iam_user_kwargs,
        RedshiftConnectionMethod.IAM_ROLE: _iam_role_kwargs,
        RedshiftConnectionMethod.IAM_IDENTITY_CENTER_BROWSER: _iam_idc_browser_kwargs,
    }

    kwargs_function = method_to_kwargs_function.get(credentials.method)
    if kwargs_function is None:
        raise FailedToConnectError(f"Invalid 'method' in profile: '{credentials.method}'")

    kwargs: Dict[str, Any] = kwargs_function(credentials)

    def connect() -> redshift_connector.Connection:
        c = redshift_connector.connect(**kwargs)
        if credentials.autocommit:
            c.autocommit = True
        if credentials.role:
            # NOTE(review): the role name is interpolated directly into the
            # statement; it comes from the profile -- confirm it is trusted.
            c.cursor().execute(f"set role {credentials.role}")
        return c

    return connect
- "redshift-connector<2.1.1,>=2.0.913,!=2.0.914", + "redshift-connector>=2.1.3,<2.2", # add dbt-core to ensure backwards compatibility of installation, this is not a functional dependency "dbt-core>=1.8.0b3", # installed via dbt-core but referenced directly; don't pin to avoid version conflicts with dbt-core diff --git a/tests/unit/test_auth_method.py b/tests/unit/test_auth_method.py index 55b1aad74..16d13268f 100644 --- a/tests/unit/test_auth_method.py +++ b/tests/unit/test_auth_method.py @@ -9,7 +9,7 @@ Plugin as RedshiftPlugin, RedshiftAdapter, ) -from dbt.adapters.redshift.connections import RedshiftConnectMethodFactory, RedshiftSSLConfig +from dbt.adapters.redshift.connections import get_connection_method, RedshiftSSLConfig from tests.unit.utils import config_from_parts_or_dicts, inject_adapter @@ -61,7 +61,7 @@ def test_invalid_auth_method(self): # we have to set method this way, otherwise it won't validate self.config.credentials.method = "badmethod" with self.assertRaises(FailedToConnectError) as context: - connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) + connect_method_factory = get_connection_method(self.config.credentials) connect_method_factory.get_connect_method() self.assertTrue("badmethod" in context.exception.msg) @@ -221,7 +221,7 @@ def test_iam_optionals(self): def test_no_cluster_id(self): self.config.credentials = self.config.credentials.replace(method="iam") with self.assertRaises(FailedToConnectError) as context: - connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) + connect_method_factory = get_connection_method(self.config.credentials) connect_method_factory.get_connect_method() self.assertTrue("'cluster_id' must be provided" in context.exception.msg) @@ -400,7 +400,7 @@ class TestIAMRoleMethod(AuthMethod): def test_no_cluster_id(self): self.config.credentials = self.config.credentials.replace(method="iam_role") with self.assertRaises(FailedToConnectError) as context: - 
class TestIAMIdcBrowser(AuthMethod):
    """Unit tests for the ``browser_identity_center`` authentication method."""

    @mock.patch("redshift_connector.connect", MagicMock())
    def test_profile_idc_browser_all_fields(self):
        """Every identity center field set in the profile reaches the driver.

        ``idp_response_timeout=0`` must be passed through unchanged rather
        than replaced by the default.
        """
        self.config.credentials = self.config.credentials.replace(
            method="browser_identity_center",
            idc_region="us-east-1",
            issuer_url="https://identitycenter.amazonaws.com/ssoins-randomchars",
            idc_client_display_name="display name",
            idp_response_timeout=0,
            host="doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com",
            idp_listen_port=1111,
        )
        connection = self.adapter.acquire_connection("dummy")
        # Accessing the attribute is what drives the connect call asserted below.
        connection.handle
        redshift_connector.connect.assert_called_once_with(
            iam=False,
            host="doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com",
            database="redshift",
            cluster_identifier=None,
            region=None,
            auto_create=False,
            db_groups=[],
            password="",
            user="",
            timeout=None,
            port=5439,
            **DEFAULT_SSL_CONFIG,
            idp_response_timeout=0,
            idc_client_display_name="display name",
            credentials_provider="BrowserIdcAuthPlugin",
            idc_region="us-east-1",
            issuer_url="https://identitycenter.amazonaws.com/ssoins-randomchars",
            listen_port=1111,
        )

    @mock.patch("redshift_connector.connect", MagicMock())
    def test_profile_idc_browser_required_fields_only(self):
        """With only the required fields set, the defaults are sent to the driver.

        Expected defaults: listen port 7890, response timeout 60 and display
        name "Amazon Redshift driver".
        """
        self.config.credentials = self.config.credentials.replace(
            method="browser_identity_center",
            idc_region="us-east-1",
            issuer_url="https://identitycenter.amazonaws.com/ssoins-randomchars",
            host="doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com",
        )
        connection = self.adapter.acquire_connection("dummy")
        # Accessing the attribute is what drives the connect call asserted below.
        connection.handle
        redshift_connector.connect.assert_called_once_with(
            iam=False,
            host="doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com",
            database="redshift",
            cluster_identifier=None,
            region=None,
            auto_create=False,
            db_groups=[],
            password="",
            user="",
            timeout=None,
            port=5439,
            **DEFAULT_SSL_CONFIG,
            credentials_provider="BrowserIdcAuthPlugin",
            listen_port=7890,
            idp_response_timeout=60,
            idc_client_display_name="Amazon Redshift driver",
            idc_region="us-east-1",
            issuer_url="https://identitycenter.amazonaws.com/ssoins-randomchars",
        )

    @mock.patch("redshift_connector.connect", MagicMock())
    def test_invalid_adapter_missing_fields(self):
        """Omitting ``idc_region`` and ``issuer_url`` fails before connecting."""
        self.config.credentials = self.config.credentials.replace(
            method="browser_identity_center",
            idp_listen_port=1111,
            idc_client_display_name="my display",
        )
        with self.assertRaises(FailedToConnectError) as context:
            connection = self.adapter.acquire_connection("dummy")
            connection.handle

        # Checked outside the ``with`` block: statements placed after the
        # raising line inside it would never execute.
        redshift_connector.connect.assert_not_called()
        self.assertIn(
            "'idc_region', 'issuer_url' field(s) are required for 'browser_identity_center' credentials method",
            context.exception.msg,
        )