Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify RedshiftConnectionManager to extend from SQLConnectionManager, migrate from psycopg2 to redshift python connector #251

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
ab4b269
Change RedshiftConnectionManager to extend from SQLConnectionManager,…
sathiish-kumar Dec 19, 2022
ff9fdfd
Add/fix unit tests, create RedshiftConnectMethodFactory to vend conne…
sathiish-kumar Dec 30, 2022
fbd5731
Fix _connection_keys to mimic PostgresConnectionManager
sathiish-kumar Jan 3, 2023
4f98546
Remove unneeded functions for tmp_cluster_creds and env_var creds aut…
sathiish-kumar Jan 9, 2023
f724708
Merge branch 'main' into migrate_psycopg2_to_rshift_connector
sathiish-kumar Jan 18, 2023
5319b90
Resolve some TODOs
sathiish-kumar Jan 18, 2023
16666db
Fix references to old exceptions, add changelog
sathiish-kumar Jan 18, 2023
30ae0b5
Fix errors with functional tests by overriding add_query & execute an…
sathiish-kumar Jan 23, 2023
de7c411
Merge branch 'dbt-labs:main' into migrate_psycopg2_to_rshift_connector
sathiish-kumar Jan 24, 2023
bfe8678
Attempt to fix integration tests by adding `valid_incremental_strateg…
sathiish-kumar Jan 24, 2023
c8a18d8
Fix unit tests
sathiish-kumar Jan 24, 2023
40e0fe5
Attempt to fix integration tests
sathiish-kumar Jan 25, 2023
c74d3ee
Merge branch 'main' into migrate_psycopg2_to_rshift_connector
sathiish-kumar Jan 26, 2023
66c1594
add unit tests for execute
jiezhen-chen Jan 26, 2023
54bc39f
Merge branch 'main' into migrate_psycopg2_to_rshift_connector
sathiish-kumar Jan 27, 2023
3ed9876
Merge branch 'main' into migrate_psycopg2_to_rshift_connector
sathiish-kumar Jan 27, 2023
4bb97ab
Merge branch 'main' into migrate_psycopg2_to_rshift_connector
sathiish-kumar Jan 30, 2023
cfad7ff
add unit tests for add_query
jiezhen-chen Jan 30, 2023
b1c8e00
Merge branch 'main' into migrate_psycopg2_to_rshift_connector
colin-rogers-dbt Jan 30, 2023
12eb89b
make get_connection_method work with serverless
jiezhen-chen Jan 31, 2023
9a319ac
Merge branch 'main' into migrate_psycopg2_to_rshift_connector
colin-rogers-dbt Jan 31, 2023
d3113ca
add unit tests for serverless iam connections
jiezhen-chen Jan 31, 2023
f5743b4
Merge branch 'main' into migrate_psycopg2_to_rshift_connector
colin-rogers-dbt Feb 8, 2023
880941c
add redshift connector version, remove sslmode, connection time out, …
jiezhen-chen Feb 9, 2023
8527832
change redshift_connector version
jiezhen-chen Feb 9, 2023
405a702
Merge branch 'main' into migrate_psycopg2_to_rshift_connector
colin-rogers-dbt Feb 10, 2023
0a5899e
Merge branch 'main' into migrate_psycopg2_to_rshift_connector
colin-rogers-dbt Feb 15, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changes/unreleased/Under the Hood-20230118-071542.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
kind: Under the Hood
body: Replace psycopg2 connector with Redshift python connector when connecting to
Redshift
time: 2023-01-18T07:15:42.183304-08:00
custom:
Author: sathiish-kumar
Issue: "219"
PR: "251"
272 changes: 191 additions & 81 deletions dbt/adapters/redshift/connections.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
import re
from multiprocessing import Lock
from contextlib import contextmanager
from typing import NewType
from typing import NewType, Tuple

from dbt.adapters.postgres import PostgresConnectionManager
from dbt.adapters.postgres import PostgresCredentials
import agate
import sqlparse
from dbt.adapters.sql import SQLConnectionManager
from dbt.contracts.connection import AdapterResponse, Connection, Credentials
from dbt.events import AdapterLogger
import dbt.exceptions
import dbt.flags

import boto3

import redshift_connector
from dbt.dataclass_schema import FieldEncoder, dbtClassMixin, StrEnum

from dataclasses import dataclass, field
from typing import Optional, List

from dbt.helper_types import Port
from redshift_connector import OperationalError, DatabaseError, DataError

logger = AdapterLogger("Redshift")

drop_lock: Lock = dbt.flags.MP_CONTEXT.Lock() # type: ignore
Expand All @@ -38,33 +42,154 @@ class RedshiftConnectionMethod(StrEnum):


@dataclass
class RedshiftCredentials(Credentials):
    """Profile credentials for connecting to Redshift via redshift_connector.

    Extends dbt's base ``Credentials`` directly (no longer PostgresCredentials),
    since the adapter now uses the Redshift python connector instead of psycopg2.
    """

    host: str
    user: str
    port: Port
    # Default to username/password auth when 'method' is omitted from the profile.
    method: str = RedshiftConnectionMethod.DATABASE  # type: ignore
    password: Optional[str] = None  # type: ignore
    cluster_id: Optional[str] = field(
        default=None,
        metadata={"description": "If using IAM auth, the name of the cluster"},
    )
    iam_profile: Optional[str] = None
    iam_duration_seconds: int = 900
    search_path: Optional[str] = None
    keepalives_idle: int = 4
    autocreate: bool = False
    db_groups: List[str] = field(default_factory=list)
    ra3_node: Optional[bool] = False
    connect_timeout: int = 30
    role: Optional[str] = None
    sslmode: Optional[str] = None
    retries: int = 1

    _ALIASES = {"dbname": "database", "pass": "password"}

    @property
    def type(self):
        return "redshift"

    def _connection_keys(self):
        # Keys surfaced by `dbt debug`; chosen to mimic PostgresConnectionManager.
        return (
            "host",
            "port",
            "user",
            "database",
            "schema",
            "method",
            "cluster_id",
            "iam_profile",
        )

    @property
    def unique_field(self) -> str:
        # Used by dbt to anonymously identify the target warehouse.
        return self.host


class RedshiftConnectMethodFactory:
    """Vends a zero-argument connect() callable for the configured auth method."""

    credentials: RedshiftCredentials

    def __init__(self, credentials):
        self.credentials = credentials

    def get_connect_method(self):
        """Return a callable that opens a redshift_connector connection.

        Raises FailedToConnectError for invalid or incomplete profile settings.
        """
        method = self.credentials.method
        kwargs = {
            "host": self.credentials.host,
            "database": self.credentials.database,
            "port": self.credentials.port if self.credentials.port else 5439,
            "auto_create": self.credentials.autocreate,
            "db_groups": self.credentials.db_groups,
            # NOTE(review): assumes a host of the form
            # <cluster>.<hash>.<region>.redshift.amazonaws.com — confirm this holds
            # for custom DNS / VPC endpoints before relying on it.
            "region": self.credentials.host.split(".")[2],
            "timeout": self.credentials.connect_timeout,
        }
        if self.credentials.sslmode:
            kwargs["sslmode"] = self.credentials.sslmode

        # Support missing 'method' for backwards compatibility
        if method == RedshiftConnectionMethod.DATABASE or method is None:
            # this requirement is really annoying to encode into json schema,
            # so validate it here
            if self.credentials.password is None:
                raise dbt.exceptions.FailedToConnectError(
                    "'password' field is required for 'database' credentials"
                )

            def connect():
                logger.debug("Connecting to redshift with username/password based auth...")
                c = redshift_connector.connect(
                    user=self.credentials.user, password=self.credentials.password, **kwargs
                )
                if self.credentials.role:
                    c.cursor().execute("set role {}".format(self.credentials.role))
                return c

            return connect

        elif method == RedshiftConnectionMethod.IAM:
            # Provisioned clusters need cluster_id; serverless is recognized by host.
            if not self.credentials.cluster_id and "serverless" not in self.credentials.host:
                raise dbt.exceptions.FailedToConnectError(
                    "Failed to use IAM method. 'cluster_id' must be provided for provisioned cluster. "
                    "'host' must be provided for serverless endpoint."
                )

            def connect():
                logger.debug("Connecting to redshift with IAM based auth...")
                c = redshift_connector.connect(
                    iam=True,
                    db_user=self.credentials.user,
                    password="",
                    user="",
                    cluster_identifier=self.credentials.cluster_id,
                    profile=self.credentials.iam_profile,
                    **kwargs,
                )
                if self.credentials.role:
                    c.cursor().execute("set role {}".format(self.credentials.role))
                return c

            return connect
        else:
            raise dbt.exceptions.FailedToConnectError(
                "Invalid 'method' in profile: '{}'".format(method)
            )


class RedshiftConnectionManager(SQLConnectionManager):
TYPE = "redshift"

def _get_backend_pid(self):
sql = "select pg_backend_pid()"
_, cursor = self.add_query(sql)
res = cursor.fetchone()
return res

def cancel(self, connection: Connection):
    """Terminate the query running on `connection` via pg_terminate_backend.

    A connection that the driver reports as already closed is a no-op.
    """
    connection_name = connection.name
    try:
        backend_pid = self._get_backend_pid()
        terminate_sql = "select pg_terminate_backend({})".format(backend_pid)
        _, terminate_cursor = self.add_query(terminate_sql)
        outcome = terminate_cursor.fetchone()
        logger.debug("Cancel query '{}': {}".format(connection_name, outcome))
    except redshift_connector.error.InterfaceError as exc:
        # Nothing to cancel on a closed connection; anything else propagates.
        if "is closed" in str(exc):
            logger.debug(f"Connection {connection_name} was already closed")
            return
        raise

@classmethod
def get_response(cls, cursor: redshift_connector.Cursor) -> AdapterResponse:
    """Build an AdapterResponse from the cursor's rowcount.

    redshift_connector's cursor exposes no `statusmessage` (unlike psycopg2),
    so the message is synthesized from rowcount alone.
    """
    rows = cursor.rowcount
    message = f"cursor.rowcount = {rows}"
    return AdapterResponse(_message=message, rows_affected=rows)

@contextmanager
def exception_handler(self, sql):
    """Translate driver errors raised while running `sql` into dbt exceptions.

    Rolls back any open transaction before re-raising. Driver DatabaseErrors
    become DbtDatabaseError; dbt-native runtime errors pass through unchanged;
    everything else is wrapped in DbtRuntimeError.
    """
    try:
        yield
    except redshift_connector.error.DatabaseError as e:
        logger.debug(f"Redshift error: {str(e)}")
        self.rollback_if_open()
        raise dbt.exceptions.DbtDatabaseError(str(e)) from e
    except Exception as e:
        logger.debug("Error running SQL: {}", sql)
        logger.debug("Rolling back transaction.")
        self.rollback_if_open()
        # Raise DBT native exceptions as is. The original checked
        # `dbt.exceptions.Exception`, which is not a class exported by
        # dbt.exceptions — check for DbtRuntimeError instead.
        if isinstance(e, dbt.exceptions.DbtRuntimeError):
            raise
        raise dbt.exceptions.DbtRuntimeError(str(e)) from e

@contextmanager
def fresh_transaction(self, name=None):
"""On entrance to this context manager, hold an exclusive lock and
Expand All @@ -89,83 +214,68 @@ def fresh_transaction(self, name=None):
self.begin()

@classmethod
def open(cls, connection):
    """Open `connection` using the auth method from its credentials, with retries.

    Transient driver errors (OperationalError, DatabaseError, DataError) are
    retried up to `credentials.retries` times with quadratic backoff.
    """
    if connection.state == "open":
        logger.debug("Connection is already open, skipping open.")
        return connection

    credentials = connection.credentials
    connect_method_factory = RedshiftConnectMethodFactory(credentials)

    def exponential_backoff(attempt: int):
        # Quadratic backoff: 1s, 4s, 9s, ...
        return attempt * attempt

    retryable_exceptions = [OperationalError, DatabaseError, DataError]

    return cls.retry_connection(
        connection,
        connect=connect_method_factory.get_connect_method(),
        logger=logger,
        retry_limit=credentials.retries,
        retry_timeout=exponential_backoff,
        retryable_exceptions=retryable_exceptions,
    )

def execute(
    self, sql: str, auto_begin: bool = False, fetch: bool = False
) -> Tuple[AdapterResponse, agate.Table]:
    """Run `sql` and return (AdapterResponse, agate.Table).

    The table holds fetched results only when `fetch` is True; otherwise an
    empty table is returned.
    """
    _, cursor = self.add_query(sql, auto_begin)
    response = self.get_response(cursor)
    if fetch:
        table = self.get_result_from_cursor(cursor)
    else:
        # NOTE(review): relies on `dbt.clients.agate_helper` being importable as an
        # attribute of the `dbt` package — confirm an explicit import exists.
        table = dbt.clients.agate_helper.empty_table()
    return response, table

try:
return boto_client.get_cluster_credentials(
DbUser=db_user,
DbName=db_name,
ClusterIdentifier=cluster_id,
DurationSeconds=duration_s,
AutoCreate=autocreate,
DbGroups=db_groups,
)
def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False):
    """Split `sql` into individual statements and execute each non-empty one.

    redshift_connector cannot run multi-statement strings, so the input is
    split with sqlparse; statements that are empty after comment stripping are
    skipped. Returns the connection and the cursor of the last statement run.

    Raises DbtRuntimeError if no statement was executed at all.
    """
    connection = None
    cursor = None

    queries = sqlparse.split(sql)

    for query in queries:
        # Strip off comments from the current query
        without_comments = re.sub(
            re.compile(r"(\".*?\"|\'.*?\')|(/\*.*?\*/|--[^\r\n]*$)", re.MULTILINE),
            "",
            query,
        ).strip()

        if without_comments == "":
            continue

        connection, cursor = super().add_query(
            query, auto_begin, bindings=bindings, abridge_sql_log=abridge_sql_log
        )

    if cursor is None:
        conn = self.get_thread_connection()
        conn_name = conn.name if conn and conn.name else "<None>"
        raise dbt.exceptions.DbtRuntimeError(f"Tried to run invalid SQL: {sql} on {conn_name}")

    return connection, cursor

@classmethod
def get_credentials(cls, credentials):
    """Pass credentials through unchanged.

    Auth-method validation and temporary-credential handling now live in
    RedshiftConnectMethodFactory, so no transformation happens here.
    """
    return credentials
12 changes: 10 additions & 2 deletions dbt/adapters/redshift/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from dbt.adapters.base.impl import AdapterConfig
from dbt.adapters.sql import SQLAdapter
from dbt.adapters.base.meta import available
from dbt.adapters.postgres import PostgresAdapter
from dbt.adapters.redshift import RedshiftConnectionManager
from dbt.adapters.redshift.column import RedshiftColumn
from dbt.adapters.redshift import RedshiftRelation
Expand All @@ -22,7 +21,7 @@ class RedshiftConfig(AdapterConfig):
backup: Optional[bool] = True


class RedshiftAdapter(PostgresAdapter, SQLAdapter):
class RedshiftAdapter(SQLAdapter):
Relation = RedshiftRelation
ConnectionManager = RedshiftConnectionManager
Column = RedshiftColumn # type: ignore
Expand Down Expand Up @@ -91,3 +90,12 @@ def _get_catalog_schemas(self, manifest):
self.type(), exc.msg
)
)

def valid_incremental_strategies(self):
    """Builtin incremental strategies this adapter supports out of the box.

    Custom strategies defined by end users are not validated here.
    """
    supported = ["append", "delete+insert"]
    return supported

def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str:
    """Return a Redshift SQL expression adding `number` `interval`s to `add_to`."""
    offset = f"{number} {interval}"
    return f"{add_to} + interval '{offset}'"
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def _core_version(plugin_version: str = _plugin_version()) -> str:
f"dbt-core~={_core_version()}",
f"dbt-postgres~={_core_version()}",
"boto3~=1.26.26",
"redshift-connector~=2.0.910",
],
zip_safe=False,
classifiers=[
Expand Down
Loading