diff --git a/.changes/unreleased/Fixes-20240326-123703.yaml b/.changes/unreleased/Fixes-20240326-123703.yaml new file mode 100644 index 000000000..5d9bee694 --- /dev/null +++ b/.changes/unreleased/Fixes-20240326-123703.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: dbt can cancel open queries upon interrupt +time: 2024-03-26T12:37:03.17481-05:00 +custom: + Author: holly-evans + Issue: "705" diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index b0fc0825d..cc58c02a6 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -233,27 +233,26 @@ def connect(): class RedshiftConnectionManager(SQLConnectionManager): TYPE = "redshift" - def _get_backend_pid(self): - sql = "select pg_backend_pid()" - _, cursor = self.add_query(sql) - - res = cursor.fetchone() - return res[0] - def cancel(self, connection: Connection): + pid = connection.backend_pid # type: ignore + sql = f"select pg_terminate_backend({pid})" + logger.debug(f"Cancel query on: '{connection.name}' with PID: {pid}") + logger.debug(sql) + try: - pid = self._get_backend_pid() + self.add_query(sql) except redshift_connector.InterfaceError as e: if "is closed" in str(e): logger.debug(f"Connection {connection.name} was already closed") return raise - sql = f"select pg_terminate_backend({pid})" - cursor = connection.handle.cursor() - logger.debug(f"Cancel query on: '{connection.name}' with PID: {pid}") - logger.debug(sql) - cursor.execute(sql) + @classmethod + def _get_backend_pid(cls, connection): + with connection.handle.cursor() as c: + sql = "select pg_backend_pid()" + res = c.execute(sql).fetchone() + return res[0] @classmethod def get_response(cls, cursor: redshift_connector.Cursor) -> AdapterResponse: @@ -325,7 +324,7 @@ def exponential_backoff(attempt: int): redshift_connector.DataError, ] - return cls.retry_connection( + open_connection = cls.retry_connection( connection, connect=connect_method_factory.get_connect_method(), logger=logger, @@ -333,6 +332,8 @@ def exponential_backoff(attempt: int): retry_timeout=exponential_backoff, retryable_exceptions=retryable_exceptions, ) + open_connection.backend_pid = cls._get_backend_pid(open_connection) # type: ignore + return open_connection def execute( self, diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index 671e47032..0bd5f8e99 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -4,7 +4,7 @@ from unittest import mock from dbt_common.exceptions import DbtRuntimeError -from unittest.mock import Mock, call +from unittest.mock import MagicMock, call import agate import dbt @@ -67,7 +67,7 @@ def adapter(self): inject_adapter(self._adapter, RedshiftPlugin) return self._adapter - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_implicit_database_conn(self): connection = self.adapter.acquire_connection("dummy") connection.handle @@ -84,7 +84,7 @@ def test_implicit_database_conn(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_explicit_region_with_database_conn(self): self.config.method = "database" @@ -103,7 +103,7 @@ def test_explicit_region_with_database_conn(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_explicit_iam_conn_without_profile(self): self.config.credentials = self.config.credentials.replace( method="iam", @@ -129,7 +129,7 @@ def test_explicit_iam_conn_without_profile(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_conn_timeout_30(self): self.config.credentials = self.config.credentials.replace(connect_timeout=30) connection = self.adapter.acquire_connection("dummy") @@ -147,7 +147,7 @@ def test_conn_timeout_30(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_explicit_iam_conn_with_profile(self): self.config.credentials = self.config.credentials.replace( method="iam", @@ -175,7 +175,7 @@ def test_explicit_iam_conn_with_profile(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_explicit_iam_serverless_with_profile(self): self.config.credentials = self.config.credentials.replace( method="iam", @@ -201,7 +201,7 @@ def test_explicit_iam_serverless_with_profile(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_explicit_region(self): # Successful test self.config.credentials = self.config.credentials.replace( @@ -229,7 +229,7 @@ def test_explicit_region(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_explicit_region_failure(self): # Failure test with no region self.config.credentials = self.config.credentials.replace( @@ -258,7 +258,7 @@ def test_explicit_region_failure(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_explicit_invalid_region(self): # Invalid region test self.config.credentials = self.config.credentials.replace( @@ -287,7 +287,7 @@ def test_explicit_invalid_region(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_sslmode_disable(self): self.config.credentials.sslmode = "disable" connection = self.adapter.acquire_connection("dummy") @@ -306,7 +306,7 @@ def test_sslmode_disable(self): sslmode=None, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_sslmode_allow(self): self.config.credentials.sslmode = "allow" connection = self.adapter.acquire_connection("dummy") @@ -325,7 +325,7 @@ def test_sslmode_allow(self): sslmode="verify-ca", ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_sslmode_verify_full(self): self.config.credentials.sslmode = "verify-full" connection = self.adapter.acquire_connection("dummy") @@ -344,7 +344,7 @@ def test_sslmode_verify_full(self): sslmode="verify-full", ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_sslmode_verify_ca(self): self.config.credentials.sslmode = "verify-ca" connection = self.adapter.acquire_connection("dummy") @@ -363,7 +363,7 @@ def test_sslmode_verify_ca(self): sslmode="verify-ca", ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_sslmode_prefer(self): self.config.credentials.sslmode = "prefer" connection = self.adapter.acquire_connection("dummy") @@ -382,7 +382,7 @@ def test_sslmode_prefer(self): sslmode="verify-ca", ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_serverless_iam_failure(self): self.config.credentials = self.config.credentials.replace( method="iam", @@ -447,6 +447,25 @@ def test_invalid_iam_no_cluster_id(self): self.assertTrue("'cluster_id' must be provided" in context.exception.msg) + @mock.patch("redshift_connector.connect", MagicMock()) + def test_connection_has_backend_pid(self): + backend_pid = 42 + + cursor = mock.MagicMock() + execute = cursor().__enter__().execute + execute().fetchone.return_value = (backend_pid,) + redshift_connector.connect().cursor = cursor + + connection = self.adapter.acquire_connection("dummy") + connection.handle + assert connection.backend_pid == backend_pid + + execute.assert_has_calls( + [ + call("select pg_backend_pid()"), + ] + ) + def test_cancel_open_connections_empty(self): self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0) @@ -475,11 +494,32 @@ def test_cancel_open_connections_single(self): self.assertEqual(len(list(self.adapter.cancel_open_connections())), 1) add_query.assert_has_calls( [ - call("select pg_backend_pid()"), + call(f"select pg_terminate_backend({model.backend_pid})"), ] ) - master.handle.get_backend_pid.assert_not_called() + master.handle.backend_pid.assert_not_called() + + @mock.patch("redshift_connector.connect", MagicMock()) + def test_backend_pid_used_in_pg_terminate_backend(self): + with mock.patch.object(self.adapter.connections, "add_query") as add_query: + backend_pid = 42 + query_result = (backend_pid,) + + cursor = mock.MagicMock() + cursor().__enter__().execute().fetchone.return_value = query_result + redshift_connector.connect().cursor = cursor + + connection = self.adapter.acquire_connection("dummy") + connection.handle + + self.adapter.connections.cancel(connection) + + add_query.assert_has_calls( + [ + call(f"select pg_terminate_backend({backend_pid})"), + ] + ) def test_dbname_verification_is_case_insensitive(self): # Override adapter settings from setUp()