Skip to content

Commit

Permalink
adding SSO support for redshift
Browse files Browse the repository at this point in the history
Committer: Abby Whittier <[email protected]>
  • Loading branch information
Whittier committed Oct 4, 2023
1 parent 85d3720 commit 26b6355
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 2 deletions.
24 changes: 23 additions & 1 deletion dbt/adapters/redshift/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def get_message(self) -> str:
class RedshiftConnectionMethod(StrEnum):
DATABASE = "database"
IAM = "iam"
IAMR = "iamr"


class UserSSLMode(StrEnum):
Expand Down Expand Up @@ -104,9 +105,9 @@ def parse(cls, user_sslmode: UserSSLMode) -> "RedshiftSSLConfig":
@dataclass
class RedshiftCredentials(Credentials):
host: str
user: str
port: Port
method: str = RedshiftConnectionMethod.DATABASE # type: ignore
user: Optional[str] = None # type: ignore
password: Optional[str] = None # type: ignore
cluster_id: Optional[str] = field(
default=None,
Expand Down Expand Up @@ -226,6 +227,27 @@ def connect():
c.cursor().execute("set role {}".format(self.credentials.role))
return c

elif method == RedshiftConnectionMethod.IAMR:
if not self.credentials.cluster_id and "serverless" not in self.credentials.host:
raise dbt.exceptions.FailedToConnectError(
"Failed to use IAM method. 'cluster_id' must be provided for provisioned cluster. "
"'host' must be provided for serverless endpoint."
)

def connect():
logger.debug("Connecting to redshift with IAM based auth...")
c = redshift_connector.connect(
iam=True,
cluster_identifier=self.credentials.cluster_id,
profile=self.credentials.iam_profile,
**kwargs,
)
if self.credentials.autocommit:
c.autocommit = True
if self.credentials.role:
c.cursor().execute("set role {}".format(self.credentials.role))
return c

else:
raise dbt.exceptions.FailedToConnectError(
"Invalid 'method' in profile: '{}'".format(method)
Expand Down
2 changes: 2 additions & 0 deletions dbt/include/redshift/profile_template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ prompts:
hide_input: true
iam:
_fixed_method: iam
iamr:
_fixed_method: iamr
dbname:
hint: 'default database that dbt will build objects in'
schema:
Expand Down
54 changes: 53 additions & 1 deletion tests/unit/test_redshift_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
)
from dbt.clients import agate_helper
from dbt.exceptions import FailedToConnectError
from dbt.adapters.redshift.connections import RedshiftConnectMethodFactory, RedshiftSSLConfig
from dbt.adapters.redshift.connections import (
RedshiftConnectMethodFactory,
RedshiftSSLConfig,
)
from .utils import (
config_from_parts_or_dicts,
mock_connection,
Expand Down Expand Up @@ -199,6 +202,55 @@ def test_explicit_iam_serverless_with_profile(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
def test_explicit_iamr_conn_without_profile(self):
self.config.credentials = self.config.credentials.replace(
method="iamr",
cluster_id="my_redshift",
host="thishostshouldnotexist.test.us-east-1",
)
connection = self.adapter.acquire_connection("dummy")
connection.handle
redshift_connector.connect.assert_called_once_with(
iam=True,
host="thishostshouldnotexist.test.us-east-1",
database="redshift",
cluster_identifier="my_redshift",
region=None,
timeout=None,
auto_create=False,
db_groups=[],
profile=None,
port=5439,
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("boto3.Session", Mock())
def test_explicit_iamr_conn_with_profile(self):
self.config.credentials = self.config.credentials.replace(
method="iamr",
cluster_id="my_redshift",
iam_profile="test",
host="thishostshouldnotexist.test.us-east-1",
)
connection = self.adapter.acquire_connection("dummy")
connection.handle

redshift_connector.connect.assert_called_once_with(
iam=True,
host="thishostshouldnotexist.test.us-east-1",
database="redshift",
cluster_identifier="my_redshift",
region=None,
auto_create=False,
db_groups=[],
profile="test",
timeout=None,
port=5439,
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("boto3.Session", Mock())
def test_explicit_region(self):
Expand Down

0 comments on commit 26b6355

Please sign in to comment.