Skip to content

Commit

Permalink
Allow dbt to cancel connections (#718) (#767)
Browse files Browse the repository at this point in the history
Co-authored-by: Holly Evans <[email protected]>
  • Loading branch information
mikealfare and holly-evans authored Apr 17, 2024
1 parent 56aaaa7 commit 74d391a
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 32 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20240326-123703.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: dbt can cancel open queries upon interrupt
time: 2024-03-26T12:37:03.17481-05:00
custom:
Author: holly-evans
Issue: "705"
29 changes: 15 additions & 14 deletions dbt/adapters/redshift/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,27 +237,26 @@ def connect():
class RedshiftConnectionManager(SQLConnectionManager):
TYPE = "redshift"

def _get_backend_pid(self):
sql = "select pg_backend_pid()"
_, cursor = self.add_query(sql)

res = cursor.fetchone()
return res[0]

def cancel(self, connection: Connection):
pid = connection.backend_pid # type: ignore
sql = f"select pg_terminate_backend({pid})"
logger.debug(f"Cancel query on: '{connection.name}' with PID: {pid}")
logger.debug(sql)

try:
pid = self._get_backend_pid()
self.add_query(sql)
except redshift_connector.InterfaceError as e:
if "is closed" in str(e):
logger.debug(f"Connection {connection.name} was already closed")
return
raise

sql = f"select pg_terminate_backend({pid})"
cursor = connection.handle.cursor()
logger.debug(f"Cancel query on: '{connection.name}' with PID: {pid}")
logger.debug(sql)
cursor.execute(sql)
@classmethod
def _get_backend_pid(cls, connection):
with connection.handle.cursor() as c:
sql = "select pg_backend_pid()"
res = c.execute(sql).fetchone()
return res[0]

@classmethod
def get_response(cls, cursor: redshift_connector.Cursor) -> AdapterResponse:
Expand Down Expand Up @@ -327,14 +326,16 @@ def exponential_backoff(attempt: int):
redshift_connector.DataError,
]

return cls.retry_connection(
open_connection = cls.retry_connection(
connection,
connect=connect_method_factory.get_connect_method(),
logger=logger,
retry_limit=credentials.retries,
retry_timeout=exponential_backoff,
retryable_exceptions=retryable_exceptions,
)
open_connection.backend_pid = cls._get_backend_pid(open_connection) # type: ignore
return open_connection

def execute(
self,
Expand Down
76 changes: 58 additions & 18 deletions tests/unit/test_redshift_adapter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
from unittest import mock
from unittest.mock import Mock, call
from unittest.mock import MagicMock, call

import agate
import dbt
Expand Down Expand Up @@ -63,7 +63,7 @@ def adapter(self):
inject_adapter(self._adapter, RedshiftPlugin)
return self._adapter

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_implicit_database_conn(self):
connection = self.adapter.acquire_connection("dummy")
connection.handle
Expand All @@ -80,7 +80,7 @@ def test_implicit_database_conn(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_region_with_database_conn(self):
self.config.method = "database"

Expand All @@ -99,7 +99,7 @@ def test_explicit_region_with_database_conn(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_iam_conn_without_profile(self):
self.config.credentials = self.config.credentials.replace(
method="iam",
Expand All @@ -125,7 +125,7 @@ def test_explicit_iam_conn_without_profile(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_conn_timeout_30(self):
self.config.credentials = self.config.credentials.replace(connect_timeout=30)
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -143,7 +143,7 @@ def test_conn_timeout_30(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_iam_conn_with_profile(self):
self.config.credentials = self.config.credentials.replace(
method="iam",
Expand Down Expand Up @@ -171,7 +171,7 @@ def test_explicit_iam_conn_with_profile(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_iam_serverless_with_profile(self):
self.config.credentials = self.config.credentials.replace(
method="iam",
Expand All @@ -197,7 +197,7 @@ def test_explicit_iam_serverless_with_profile(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_region(self):
# Successful test
self.config.credentials = self.config.credentials.replace(
Expand Down Expand Up @@ -225,7 +225,7 @@ def test_explicit_region(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_region_failure(self):
# Failure test with no region
self.config.credentials = self.config.credentials.replace(
Expand Down Expand Up @@ -254,7 +254,7 @@ def test_explicit_region_failure(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_explicit_invalid_region(self):
# Invalid region test
self.config.credentials = self.config.credentials.replace(
Expand Down Expand Up @@ -283,7 +283,7 @@ def test_explicit_invalid_region(self):
**DEFAULT_SSL_CONFIG,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_sslmode_disable(self):
self.config.credentials.sslmode = "disable"
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -302,7 +302,7 @@ def test_sslmode_disable(self):
sslmode=None,
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_sslmode_allow(self):
self.config.credentials.sslmode = "allow"
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -321,7 +321,7 @@ def test_sslmode_allow(self):
sslmode="verify-ca",
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_sslmode_verify_full(self):
self.config.credentials.sslmode = "verify-full"
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -340,7 +340,7 @@ def test_sslmode_verify_full(self):
sslmode="verify-full",
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_sslmode_verify_ca(self):
self.config.credentials.sslmode = "verify-ca"
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -359,7 +359,7 @@ def test_sslmode_verify_ca(self):
sslmode="verify-ca",
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_sslmode_prefer(self):
self.config.credentials.sslmode = "prefer"
connection = self.adapter.acquire_connection("dummy")
Expand All @@ -378,7 +378,7 @@ def test_sslmode_prefer(self):
sslmode="verify-ca",
)

@mock.patch("redshift_connector.connect", Mock())
@mock.patch("redshift_connector.connect", MagicMock())
def test_serverless_iam_failure(self):
self.config.credentials = self.config.credentials.replace(
method="iam",
Expand Down Expand Up @@ -443,6 +443,25 @@ def test_invalid_iam_no_cluster_id(self):

self.assertTrue("'cluster_id' must be provided" in context.exception.msg)

@mock.patch("redshift_connector.connect", MagicMock())
def test_connection_has_backend_pid(self):
backend_pid = 42

cursor = mock.MagicMock()
execute = cursor().__enter__().execute
execute().fetchone.return_value = (backend_pid,)
redshift_connector.connect().cursor = cursor

connection = self.adapter.acquire_connection("dummy")
connection.handle
assert connection.backend_pid == backend_pid

execute.assert_has_calls(
[
call("select pg_backend_pid()"),
]
)

def test_cancel_open_connections_empty(self):
self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0)

Expand Down Expand Up @@ -471,11 +490,32 @@ def test_cancel_open_connections_single(self):
self.assertEqual(len(list(self.adapter.cancel_open_connections())), 1)
add_query.assert_has_calls(
[
call("select pg_backend_pid()"),
call(f"select pg_terminate_backend({model.backend_pid})"),
]
)

master.handle.get_backend_pid.assert_not_called()
master.handle.backend_pid.assert_not_called()

@mock.patch("redshift_connector.connect", MagicMock())
def test_backend_pid_used_in_pg_terminate_backend(self):
with mock.patch.object(self.adapter.connections, "add_query") as add_query:
backend_pid = 42
query_result = (backend_pid,)

cursor = mock.MagicMock()
cursor().__enter__().execute().fetchone.return_value = query_result
redshift_connector.connect().cursor = cursor

connection = self.adapter.acquire_connection("dummy")
connection.handle

self.adapter.connections.cancel(connection)

add_query.assert_has_calls(
[
call(f"select pg_terminate_backend({backend_pid})"),
]
)

def test_dbname_verification_is_case_insensitive(self):
# Override adapter settings from setUp()
Expand Down

0 comments on commit 74d391a

Please sign in to comment.