Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow dbt to cancel connections #718

Merged
merged 18 commits into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20240326-123703.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: dbt can cancel open queries upon interrupt
time: 2024-03-26T12:37:03.17481-05:00
custom:
Author: holly-evans
Issue: "705"
29 changes: 15 additions & 14 deletions dbt/adapters/redshift/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,27 +233,26 @@ def connect():
class RedshiftConnectionManager(SQLConnectionManager):
TYPE = "redshift"

def _get_backend_pid(self):
mikealfare marked this conversation as resolved.
Show resolved Hide resolved
sql = "select pg_backend_pid()"
_, cursor = self.add_query(sql)

res = cursor.fetchone()
return res[0]

def cancel(self, connection: Connection):
pid = connection.backend_pid
holly-evans marked this conversation as resolved.
Show resolved Hide resolved
holly-evans marked this conversation as resolved.
Show resolved Hide resolved
sql = f"select pg_terminate_backend({pid})"
logger.debug(f"Cancel query on: '{connection.name}' with PID: {pid}")
logger.debug(sql)

try:
pid = self._get_backend_pid()
self.add_query(sql)
except redshift_connector.InterfaceError as e:
if "is closed" in str(e):
logger.debug(f"Connection {connection.name} was already closed")
return
raise

sql = f"select pg_terminate_backend({pid})"
cursor = connection.handle.cursor()
logger.debug(f"Cancel query on: '{connection.name}' with PID: {pid}")
logger.debug(sql)
cursor.execute(sql)
@classmethod
def _get_backend_pid(cls, connection):
with connection.handle.cursor() as c:
sql = "select pg_backend_pid()"
res = c.execute(sql).fetchone()
return res[0]

@classmethod
def get_response(cls, cursor: redshift_connector.Cursor) -> AdapterResponse:
Expand Down Expand Up @@ -325,14 +324,16 @@ def exponential_backoff(attempt: int):
redshift_connector.DataError,
]

return cls.retry_connection(
open_connection = cls.retry_connection(
connection,
connect=connect_method_factory.get_connect_method(),
logger=logger,
retry_limit=credentials.retries,
retry_timeout=exponential_backoff,
retryable_exceptions=retryable_exceptions,
)
open_connection.backend_pid = cls._get_backend_pid(open_connection)
holly-evans marked this conversation as resolved.
Show resolved Hide resolved
return open_connection

def execute(
self,
Expand Down
76 changes: 58 additions & 18 deletions tests/unit/test_redshift_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from unittest import mock

from dbt_common.exceptions import DbtRuntimeError
from unittest.mock import Mock, call
from unittest.mock import MagicMock, call
mikealfare marked this conversation as resolved.
Show resolved Hide resolved

import agate
import dbt
Expand Down Expand Up @@ -67,7 +67,7 @@ def adapter(self):
inject_adapter(self._adapter, RedshiftPlugin)
return self._adapter

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_implicit_database_conn(self):
connection = self.adapter.acquire_connection("dummy")
connection.handle
Expand All @@ -84,7 +84,7 @@ def test_implicit_database_conn(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_region_with_database_conn(self):
self.config.method = "database"

Expand All @@ -103,7 +103,7 @@ def test_explicit_region_with_database_conn(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_iam_conn_without_profile(self):
self.config.credentials = self.config.credentials.replace(
method="iam",
Expand All @@ -129,7 +129,7 @@ def test_explicit_iam_conn_without_profile(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_conn_timeout_30(self):
self.config.credentials = self.config.credentials.replace(connect_timeout=30)
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -147,7 +147,7 @@ def test_conn_timeout_30(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_iam_conn_with_profile(self):
self.config.credentials = self.config.credentials.replace(
method="iam",
Expand Down Expand Up @@ -175,7 +175,7 @@ def test_explicit_iam_conn_with_profile(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_iam_serverless_with_profile(self):
self.config.credentials = self.config.credentials.replace(
method="iam",
Expand All @@ -201,7 +201,7 @@ def test_explicit_iam_serverless_with_profile(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_region(self):
# Successful test
self.config.credentials = self.config.credentials.replace(
Expand Down Expand Up @@ -229,7 +229,7 @@ def test_explicit_region(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_region_failure(self):
# Failure test with no region
self.config.credentials = self.config.credentials.replace(
Expand Down Expand Up @@ -258,7 +258,7 @@ def test_explicit_region_failure(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_invalid_region(self):
# Invalid region test
self.config.credentials = self.config.credentials.replace(
Expand Down Expand Up @@ -287,7 +287,7 @@ def test_explicit_invalid_region(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_sslmode_disable(self):
self.config.credentials.sslmode = "disable"
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -306,7 +306,7 @@ def test_sslmode_disable(self):
sslmode=None,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_sslmode_allow(self):
self.config.credentials.sslmode = "allow"
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -325,7 +325,7 @@ def test_sslmode_allow(self):
sslmode="verify-ca",
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_sslmode_verify_full(self):
self.config.credentials.sslmode = "verify-full"
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -344,7 +344,7 @@ def test_sslmode_verify_full(self):
sslmode="verify-full",
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_sslmode_verify_ca(self):
self.config.credentials.sslmode = "verify-ca"
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -363,7 +363,7 @@ def test_sslmode_verify_ca(self):
sslmode="verify-ca",
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_sslmode_prefer(self):
self.config.credentials.sslmode = "prefer"
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -382,7 +382,7 @@ def test_sslmode_prefer(self):
sslmode="verify-ca",
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_serverless_iam_failure(self):
self.config.credentials = self.config.credentials.replace(
method="iam",
Expand Down Expand Up @@ -447,6 +447,25 @@ def test_invalid_iam_no_cluster_id(self):

self.assertTrue("'cluster_id' must be provided" in context.exception.msg)

@mock.patch("redshift_connector.connect", MagicMock())
def test_connection_has_backend_pid(self):
mikealfare marked this conversation as resolved.
Show resolved Hide resolved
backend_pid = 42

cursor = mock.MagicMock()
execute = cursor().__enter__().execute
execute().fetchone.return_value = (backend_pid,)
redshift_connector.connect().cursor = cursor

connection = self.adapter.acquire_connection("dummy")
connection.handle
assert connection.backend_pid == backend_pid

execute.assert_has_calls(
[
call("select pg_backend_pid()"),
]
)

def test_cancel_open_connections_empty(self):
self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0)

Expand Down Expand Up @@ -475,11 +494,32 @@ def test_cancel_open_connections_single(self):
self.assertEqual(len(list(self.adapter.cancel_open_connections())), 1)
add_query.assert_has_calls(
[
call("select pg_backend_pid()"),
call(f"select pg_terminate_backend({model.backend_pid})"),
]
)

master.handle.get_backend_pid.assert_not_called()
master.handle.backend_pid.assert_not_called()

@mock.patch("redshift_connector.connect", MagicMock())
def test_backend_pid_used_in_pg_terminate_backend(self):
mikealfare marked this conversation as resolved.
Show resolved Hide resolved
with mock.patch.object(self.adapter.connections, "add_query") as add_query:
backend_pid = 42
query_result = (backend_pid,)

cursor = mock.MagicMock()
cursor().__enter__().execute().fetchone.return_value = query_result
redshift_connector.connect().cursor = cursor

connection = self.adapter.acquire_connection("dummy")
connection.handle

self.adapter.connections.cancel(connection)

add_query.assert_has_calls(
[
call(f"select pg_terminate_backend({backend_pid})"),
]
)

def test_dbname_verification_is_case_insensitive(self):
# Override adapter settings from setUp()
Expand Down
Loading