From 8e4cd96d43092fb706db66b0e70d54cd27aac8a8 Mon Sep 17 00:00:00 2001 From: Holly Evans <39742776+holly-evans@users.noreply.github.com> Date: Wed, 21 Feb 2024 14:30:01 -0600 Subject: [PATCH 01/10] Save backend_pid on initiation --- dbt/adapters/redshift/connections.py | 40 +++++++++++++++++++++------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index b0fc0825d..7476cce95 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -156,6 +156,32 @@ def _connection_keys(self): def unique_field(self) -> str: return self.host +class RedshiftSQLConnectionWrapper: + """Wrap a Redshift SQL connector in a way that stores backend pid""" + + _conn: redshift_connector.Connection + _backend_pid: int + + def __init__( + self, + conn: redshift_connector.Connection + ): + self._conn = conn + self._backend_pid = self._get_backend_pid() + + def __getattr__(self, name): + return getattr(self._conn, name) + + def _get_backend_pid(self): + sql = "select pg_backend_pid()" + cursor = self.cursor().execute(sql) + res = cursor.fetchone() + return res[0] + + @property + def backend_pid(self): + return self._backend_pid + class RedshiftConnectMethodFactory: credentials: RedshiftCredentials @@ -194,6 +220,7 @@ def connect(): password=self.credentials.password, **kwargs, ) + c = RedshiftSQLConnectionWrapper(c) if self.credentials.autocommit: c.autocommit = True if self.credentials.role: @@ -218,6 +245,7 @@ def connect(): profile=self.credentials.iam_profile, **kwargs, ) + c = RedshiftSQLConnectionWrapper(c) if self.credentials.autocommit: c.autocommit = True if self.credentials.role: @@ -233,16 +261,9 @@ def connect(): class RedshiftConnectionManager(SQLConnectionManager): TYPE = "redshift" - def _get_backend_pid(self): - sql = "select pg_backend_pid()" - _, cursor = self.add_query(sql) - - res = cursor.fetchone() - return res[0] - def cancel(self, connection: 
Connection): try: - pid = self._get_backend_pid() + pid = connection.handle.backend_pid except redshift_connector.InterfaceError as e: if "is closed" in str(e): logger.debug(f"Connection {connection.name} was already closed") @@ -250,10 +271,9 @@ def cancel(self, connection: Connection): raise sql = f"select pg_terminate_backend({pid})" - cursor = connection.handle.cursor() logger.debug(f"Cancel query on: '{connection.name}' with PID: {pid}") logger.debug(sql) - cursor.execute(sql) + self.add_query(sql) @classmethod def get_response(cls, cursor: redshift_connector.Cursor) -> AdapterResponse: From 5fad97640a9909a8b689c064fd83f789e8eb1651 Mon Sep 17 00:00:00 2001 From: Holly Evans <39742776+holly-evans@users.noreply.github.com> Date: Fri, 23 Feb 2024 17:07:37 -0600 Subject: [PATCH 02/10] Adjust OnConfigurationChangeOption import --- .../adapter/materialized_view_tests/test_materialized_views.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/adapter/materialized_view_tests/test_materialized_views.py b/tests/functional/adapter/materialized_view_tests/test_materialized_views.py index 63bcede61..648df81a3 100644 --- a/tests/functional/adapter/materialized_view_tests/test_materialized_views.py +++ b/tests/functional/adapter/materialized_view_tests/test_materialized_views.py @@ -3,7 +3,7 @@ import pytest from dbt.adapters.base.relation import BaseRelation -from dbt.contracts.graph.model_config import OnConfigurationChangeOption +from dbt_common.contracts.config.materialization import OnConfigurationChangeOption from dbt.tests.adapter.materialized_view.basic import MaterializedViewBasic from dbt.tests.adapter.materialized_view.changes import ( From ef4e9fee80526e054c1d1c50e771de6f10b9fe29 Mon Sep 17 00:00:00 2001 From: Holly Evans <39742776+holly-evans@users.noreply.github.com> Date: Fri, 23 Feb 2024 17:08:33 -0600 Subject: [PATCH 03/10] Use MagicMock to accommodate __getitem__ call on cursor results --- tests/unit/test_redshift_adapter.py
| 32 ++++++++++++++--------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index 671e47032..30cbde857 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -4,7 +4,7 @@ from unittest import mock from dbt_common.exceptions import DbtRuntimeError -from unittest.mock import Mock, call +from unittest.mock import MagicMock, call import agate import dbt @@ -67,7 +67,7 @@ def adapter(self): inject_adapter(self._adapter, RedshiftPlugin) return self._adapter - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_implicit_database_conn(self): connection = self.adapter.acquire_connection("dummy") connection.handle @@ -84,7 +84,7 @@ def test_implicit_database_conn(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_explicit_region_with_database_conn(self): self.config.method = "database" @@ -103,7 +103,7 @@ def test_explicit_region_with_database_conn(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_explicit_iam_conn_without_profile(self): self.config.credentials = self.config.credentials.replace( method="iam", @@ -129,7 +129,7 @@ def test_explicit_iam_conn_without_profile(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_conn_timeout_30(self): self.config.credentials = self.config.credentials.replace(connect_timeout=30) connection = self.adapter.acquire_connection("dummy") @@ -147,7 +147,7 @@ def test_conn_timeout_30(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def 
test_explicit_iam_conn_with_profile(self): self.config.credentials = self.config.credentials.replace( method="iam", @@ -175,7 +175,7 @@ def test_explicit_iam_conn_with_profile(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_explicit_iam_serverless_with_profile(self): self.config.credentials = self.config.credentials.replace( method="iam", @@ -201,7 +201,7 @@ def test_explicit_iam_serverless_with_profile(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_explicit_region(self): # Successful test self.config.credentials = self.config.credentials.replace( @@ -229,7 +229,7 @@ def test_explicit_region(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_explicit_region_failure(self): # Failure test with no region self.config.credentials = self.config.credentials.replace( @@ -258,7 +258,7 @@ def test_explicit_region_failure(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_explicit_invalid_region(self): # Invalid region test self.config.credentials = self.config.credentials.replace( @@ -287,7 +287,7 @@ def test_explicit_invalid_region(self): **DEFAULT_SSL_CONFIG, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_sslmode_disable(self): self.config.credentials.sslmode = "disable" connection = self.adapter.acquire_connection("dummy") @@ -306,7 +306,7 @@ def test_sslmode_disable(self): sslmode=None, ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_sslmode_allow(self): self.config.credentials.sslmode = "allow" connection = 
self.adapter.acquire_connection("dummy") @@ -325,7 +325,7 @@ def test_sslmode_allow(self): sslmode="verify-ca", ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_sslmode_verify_full(self): self.config.credentials.sslmode = "verify-full" connection = self.adapter.acquire_connection("dummy") @@ -344,7 +344,7 @@ def test_sslmode_verify_full(self): sslmode="verify-full", ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_sslmode_verify_ca(self): self.config.credentials.sslmode = "verify-ca" connection = self.adapter.acquire_connection("dummy") @@ -363,7 +363,7 @@ def test_sslmode_verify_ca(self): sslmode="verify-ca", ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_sslmode_prefer(self): self.config.credentials.sslmode = "prefer" connection = self.adapter.acquire_connection("dummy") @@ -382,7 +382,7 @@ def test_sslmode_prefer(self): sslmode="verify-ca", ) - @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("redshift_connector.connect", MagicMock()) def test_serverless_iam_failure(self): self.config.credentials = self.config.credentials.replace( method="iam", From 0b7ac0e2165558a6422eb2c90733ddf901e55eb4 Mon Sep 17 00:00:00 2001 From: Holly Evans <39742776+holly-evans@users.noreply.github.com> Date: Fri, 23 Feb 2024 17:09:46 -0600 Subject: [PATCH 04/10] Test backend_pid usage --- tests/unit/test_redshift_adapter.py | 43 +++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index 30cbde857..8268f8683 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -447,6 +447,23 @@ def test_invalid_iam_no_cluster_id(self): self.assertTrue("'cluster_id' must be provided" in context.exception.msg) + 
@mock.patch("redshift_connector.connect", MagicMock()) + def test_connection_has_backend_pid(self): + backend_pid = 42 + + cursor = mock.Mock() + cursor().execute().fetchone.return_value = (backend_pid,) + redshift_connector.connect().cursor = cursor + + connection = self.adapter.acquire_connection("dummy") + assert connection.handle.backend_pid == backend_pid + + cursor().execute.assert_has_calls( + [ + call("select pg_backend_pid()"), + ] + ) + def test_cancel_open_connections_empty(self): self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0) @@ -475,11 +492,33 @@ def test_cancel_open_connections_single(self): self.assertEqual(len(list(self.adapter.cancel_open_connections())), 1) add_query.assert_has_calls( [ - call("select pg_backend_pid()"), + call(f"select pg_terminate_backend({model.handle.backend_pid})"), ] ) - master.handle.get_backend_pid.assert_not_called() + master.handle.backend_pid.assert_not_called() + + + @mock.patch("redshift_connector.connect", MagicMock()) + def test_backend_pid_used_in_pg_terminate_backend(self): + with mock.patch.object(self.adapter.connections, "add_query") as add_query: + backend_pid = 42 + query_result = (backend_pid,) + + cursor = mock.Mock() + cursor().execute().fetchone.return_value = query_result + redshift_connector.connect().cursor = cursor + + connection = self.adapter.acquire_connection("dummy") + connection.handle + + self.adapter.connections.cancel(connection) + + add_query.assert_has_calls( + [ + call(f"select pg_terminate_backend({backend_pid})"), + ] + ) def test_dbname_verification_is_case_insensitive(self): # Override adapter settings from setUp() From bcdb2f5ee05a00f17a5b8f6d9c5169ccf8ecc9f1 Mon Sep 17 00:00:00 2001 From: Holly Evans <39742776+holly-evans@users.noreply.github.com> Date: Fri, 23 Feb 2024 17:12:10 -0600 Subject: [PATCH 05/10] Lint --- dbt/adapters/redshift/connections.py | 8 +++----- tests/unit/test_redshift_adapter.py | 3 +-- 2 files changed, 4 insertions(+), 7 deletions(-) 
diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 7476cce95..ea785e6b6 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -156,21 +156,19 @@ def _connection_keys(self): def unique_field(self) -> str: return self.host + class RedshiftSQLConnectionWrapper: """Wrap a Redshift SQL connector in a way that stores backend pid""" _conn: redshift_connector.Connection _backend_pid: int - def __init__( - self, - conn: redshift_connector.Connection - ): + def __init__(self, conn: redshift_connector.Connection): self._conn = conn self._backend_pid = self._get_backend_pid() def __getattr__(self, name): - return getattr(self._conn, name) + return getattr(self._conn, name) def _get_backend_pid(self): sql = "select pg_backend_pid()" diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index 8268f8683..18953bb7f 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -462,7 +462,7 @@ def test_connection_has_backend_pid(self): [ call("select pg_backend_pid()"), ] - ) + ) def test_cancel_open_connections_empty(self): self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0) @@ -498,7 +498,6 @@ def test_cancel_open_connections_single(self): master.handle.backend_pid.assert_not_called() - @mock.patch("redshift_connector.connect", MagicMock()) def test_backend_pid_used_in_pg_terminate_backend(self): with mock.patch.object(self.adapter.connections, "add_query") as add_query: From 71a9b7f32a84b33c644f09127f642a64ba3fd88b Mon Sep 17 00:00:00 2001 From: Holly Evans <39742776+holly-evans@users.noreply.github.com> Date: Mon, 26 Feb 2024 09:46:36 -0600 Subject: [PATCH 06/10] Fix merge conflict --- .../adapter/materialized_view_tests/test_materialized_views.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/adapter/materialized_view_tests/test_materialized_views.py 
b/tests/functional/adapter/materialized_view_tests/test_materialized_views.py index 648df81a3..64f697e77 100644 --- a/tests/functional/adapter/materialized_view_tests/test_materialized_views.py +++ b/tests/functional/adapter/materialized_view_tests/test_materialized_views.py @@ -3,7 +3,7 @@ import pytest from dbt.adapters.base.relation import BaseRelation -from dbt_common.contracts.config.materialization import OnConfigurationChangeOption +from dbt.adapters.contracts.relation import OnConfigurationChangeOption from dbt.tests.adapter.materialized_view.basic import MaterializedViewBasic from dbt.tests.adapter.materialized_view.changes import ( From 8cbd4ef40e9ab495703cb58866cf28c455b28c4e Mon Sep 17 00:00:00 2001 From: Holly Evans <39742776+holly-evans@users.noreply.github.com> Date: Tue, 26 Mar 2024 12:37:37 -0500 Subject: [PATCH 07/10] Add changelog entry --- .changes/unreleased/Fixes-20240326-123703.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changes/unreleased/Fixes-20240326-123703.yaml diff --git a/.changes/unreleased/Fixes-20240326-123703.yaml b/.changes/unreleased/Fixes-20240326-123703.yaml new file mode 100644 index 000000000..5d9bee694 --- /dev/null +++ b/.changes/unreleased/Fixes-20240326-123703.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: dbt can cancel open queries upon interrupt +time: 2024-03-26T12:37:03.17481-05:00 +custom: + Author: holly-evans + Issue: "705" From f53c1fe6cdf5aa83c0179abf94d387f6735b5219 Mon Sep 17 00:00:00 2001 From: Holly Evans <39742776+holly-evans@users.noreply.github.com> Date: Fri, 12 Apr 2024 09:10:13 -0500 Subject: [PATCH 08/10] Store backend_pid on connection directly --- dbt/adapters/redshift/connections.py | 47 +++++++++------------------- tests/unit/test_redshift_adapter.py | 16 +++++----- 2 files changed, 24 insertions(+), 39 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index ea785e6b6..a008219aa 100644 --- a/dbt/adapters/redshift/connections.py 
+++ b/dbt/adapters/redshift/connections.py @@ -157,30 +157,6 @@ def unique_field(self) -> str: return self.host -class RedshiftSQLConnectionWrapper: - """Wrap a Redshift SQL connector in a way that stores backend pid""" - - _conn: redshift_connector.Connection - _backend_pid: int - - def __init__(self, conn: redshift_connector.Connection): - self._conn = conn - self._backend_pid = self._get_backend_pid() - - def __getattr__(self, name): - return getattr(self._conn, name) - - def _get_backend_pid(self): - sql = "select pg_backend_pid()" - cursor = self.cursor().execute(sql) - res = cursor.fetchone() - return res[0] - - @property - def backend_pid(self): - return self._backend_pid - - class RedshiftConnectMethodFactory: credentials: RedshiftCredentials @@ -218,7 +194,6 @@ def connect(): password=self.credentials.password, **kwargs, ) - c = RedshiftSQLConnectionWrapper(c) if self.credentials.autocommit: c.autocommit = True if self.credentials.role: @@ -243,7 +218,6 @@ def connect(): profile=self.credentials.iam_profile, **kwargs, ) - c = RedshiftSQLConnectionWrapper(c) if self.credentials.autocommit: c.autocommit = True if self.credentials.role: @@ -260,18 +234,25 @@ class RedshiftConnectionManager(SQLConnectionManager): TYPE = "redshift" def cancel(self, connection: Connection): + pid = connection.backend_pid + sql = f"select pg_terminate_backend({pid})" + logger.debug(f"Cancel query on: '{connection.name}' with PID: {pid}") + logger.debug(sql) + try: - pid = connection.handle.backend_pid + self.add_query(sql) except redshift_connector.InterfaceError as e: if "is closed" in str(e): logger.debug(f"Connection {connection.name} was already closed") return raise - sql = f"select pg_terminate_backend({pid})" - logger.debug(f"Cancel query on: '{connection.name}' with PID: {pid}") - logger.debug(sql) - self.add_query(sql) + @classmethod + def _get_backend_pid(cls, connection): + with connection.handle.cursor() as c: + sql = "select pg_backend_pid()" + res = 
c.execute(sql).fetchone() + return res[0] @classmethod def get_response(cls, cursor: redshift_connector.Cursor) -> AdapterResponse: @@ -343,7 +324,7 @@ def exponential_backoff(attempt: int): redshift_connector.DataError, ] - return cls.retry_connection( + open_connection = cls.retry_connection( connection, connect=connect_method_factory.get_connect_method(), logger=logger, @@ -351,6 +332,8 @@ def exponential_backoff(attempt: int): retry_timeout=exponential_backoff, retryable_exceptions=retryable_exceptions, ) + open_connection.backend_pid = cls._get_backend_pid(open_connection) + return open_connection def execute( self, diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index 18953bb7f..0bd5f8e99 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -451,14 +451,16 @@ def test_invalid_iam_no_cluster_id(self): def test_connection_has_backend_pid(self): backend_pid = 42 - cursor = mock.Mock() - cursor().execute().fetchone.return_value = (backend_pid,) + cursor = mock.MagicMock() + execute = cursor().__enter__().execute + execute().fetchone.return_value = (backend_pid,) redshift_connector.connect().cursor = cursor connection = self.adapter.acquire_connection("dummy") - assert connection.handle.backend_pid == backend_pid + connection.handle + assert connection.backend_pid == backend_pid - cursor().execute.assert_has_calls( + execute.assert_has_calls( [ call("select pg_backend_pid()"), ] @@ -492,7 +494,7 @@ def test_cancel_open_connections_single(self): self.assertEqual(len(list(self.adapter.cancel_open_connections())), 1) add_query.assert_has_calls( [ - call(f"select pg_terminate_backend({model.handle.backend_pid})"), + call(f"select pg_terminate_backend({model.backend_pid})"), ] ) @@ -504,8 +506,8 @@ def test_backend_pid_used_in_pg_terminate_backend(self): backend_pid = 42 query_result = (backend_pid,) - cursor = mock.Mock() - cursor().execute().fetchone.return_value = query_result + cursor = 
mock.MagicMock() + cursor().__enter__().execute().fetchone.return_value = query_result redshift_connector.connect().cursor = cursor connection = self.adapter.acquire_connection("dummy") From 789e22b001f67e9890cdefda9bab4e6f6cafbacf Mon Sep 17 00:00:00 2001 From: Holly Evans <39742776+holly-evans@users.noreply.github.com> Date: Fri, 12 Apr 2024 15:08:51 -0400 Subject: [PATCH 09/10] mypy lint fixes --- dbt/adapters/redshift/connections.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index a008219aa..4cfc7fe20 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -234,7 +234,7 @@ class RedshiftConnectionManager(SQLConnectionManager): TYPE = "redshift" def cancel(self, connection: Connection): - pid = connection.backend_pid + pid = connection.backend_pid # type: ignore sql = f"select pg_terminate_backend({pid})" logger.debug(f"Cancel query on: '{connection.name}' with PID: {pid}") logger.debug(sql) @@ -332,7 +332,7 @@ def exponential_backoff(attempt: int): retry_timeout=exponential_backoff, retryable_exceptions=retryable_exceptions, ) - open_connection.backend_pid = cls._get_backend_pid(open_connection) + open_connection.backend_pid = cls._get_backend_pid(open_connection) # type: ignore return open_connection def execute( From 0d9dd486d75cde0f5cea978ca4ca20bb96ac483e Mon Sep 17 00:00:00 2001 From: Holly Evans <39742776+holly-evans@users.noreply.github.com> Date: Fri, 12 Apr 2024 17:01:48 -0500 Subject: [PATCH 10/10] Lint! 
--- dbt/adapters/redshift/connections.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 4cfc7fe20..cc58c02a6 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -234,7 +234,7 @@ class RedshiftConnectionManager(SQLConnectionManager): TYPE = "redshift" def cancel(self, connection: Connection): - pid = connection.backend_pid # type: ignore + pid = connection.backend_pid # type: ignore sql = f"select pg_terminate_backend({pid})" logger.debug(f"Cancel query on: '{connection.name}' with PID: {pid}") logger.debug(sql) @@ -332,7 +332,7 @@ def exponential_backoff(attempt: int): retry_timeout=exponential_backoff, retryable_exceptions=retryable_exceptions, ) - open_connection.backend_pid = cls._get_backend_pid(open_connection) # type: ignore + open_connection.backend_pid = cls._get_backend_pid(open_connection) # type: ignore return open_connection def execute(