From ab506cdec25100f3db1943d79ce8cf15e4161fa2 Mon Sep 17 00:00:00 2001 From: Damian Owsianny Date: Tue, 19 Nov 2024 20:30:29 +0100 Subject: [PATCH] Add require_certificate_validation Behavior Flag --- .../Under the Hood-20241120-191809.yaml | 7 ++ dbt/adapters/trino/connections.py | 13 +++- dbt/adapters/trino/impl.py | 21 ++++++ .../test_require_certificate_validation.py | 71 +++++++++++++++++++ 4 files changed, 111 insertions(+), 1 deletion(-) create mode 100644 .changes/unreleased/Under the Hood-20241120-191809.yaml create mode 100644 tests/functional/adapter/behavior_flags/test_require_certificate_validation.py diff --git a/.changes/unreleased/Under the Hood-20241120-191809.yaml b/.changes/unreleased/Under the Hood-20241120-191809.yaml new file mode 100644 index 00000000..9f92e281 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20241120-191809.yaml @@ -0,0 +1,7 @@ +kind: Under the Hood +body: Revert cert default to False. Add require_certificate_validation Behavior Flag +time: 2024-11-20T19:18:09.725288+01:00 +custom: + Author: damian3031 + Issue: "" + PR: "447" diff --git a/dbt/adapters/trino/connections.py b/dbt/adapters/trino/connections.py index f2d49ad9..ec142276 100644 --- a/dbt/adapters/trino/connections.py +++ b/dbt/adapters/trino/connections.py @@ -423,6 +423,12 @@ class TrinoAdapterResponse(AdapterResponse): class TrinoConnectionManager(SQLConnectionManager): TYPE = "trino" + behavior_flags = None + + def __init__(self, profile, mp_context, behavior_flags=None) -> None: + super().__init__(profile, mp_context) + + TrinoConnectionManager.behavior_flags = behavior_flags @contextmanager def exception_handler(self, sql): @@ -464,6 +470,11 @@ def open(cls, connection): return connection credentials = connection.credentials + verify = ( + credentials.cert + if credentials.cert is not None + else cls.behavior_flags.require_certificate_validation.setting + ) # it's impossible for trino to fail here as 'connections' are actually # just cursor factories. @@ -484,7 +495,7 @@ def open(cls, connection): max_attempts=credentials.retries, isolation_level=IsolationLevel.AUTOCOMMIT, source=f"dbt-trino-{version}", - verify=credentials.cert, + verify=verify, timezone=credentials.timezone, ) connection.state = "open" diff --git a/dbt/adapters/trino/impl.py b/dbt/adapters/trino/impl.py index d1d87c2a..3de84cbd 100644 --- a/dbt/adapters/trino/impl.py +++ b/dbt/adapters/trino/impl.py @@ -10,6 +10,7 @@ Support, ) from dbt.adapters.sql import SQLAdapter +from dbt_common.behavior_flags import BehaviorFlag from dbt_common.contracts.constraints import ConstraintType from dbt_common.exceptions import DbtDatabaseError @@ -47,6 +48,26 @@ class TrinoAdapter(SQLAdapter): } ) + def __init__(self, config, mp_context) -> None: + super().__init__(config, mp_context) + self.connections = self.ConnectionManager(config, mp_context, self.behavior) + + @property + def _behavior_flags(self) -> list[BehaviorFlag]: + return [ + { # type: ignore + "name": "require_certificate_validation", + "default": False, + "description": ( + "SSL certificate validation is disabled by default. " + "It is legacy behavior which will be changed in future releases. " + "It is strongly advised to enable `require_certificate_validation` flag " + "or explicitly set `cert` configuration to `True` for security reasons. " + "You may receive an error after that if your SSL setup is incorrect." + ), + } + ] + @classmethod def date_function(cls): return "datenow()" diff --git a/tests/functional/adapter/behavior_flags/test_require_certificate_validation.py b/tests/functional/adapter/behavior_flags/test_require_certificate_validation.py new file mode 100644 index 00000000..ac089ebd --- /dev/null +++ b/tests/functional/adapter/behavior_flags/test_require_certificate_validation.py @@ -0,0 +1,71 @@ +import warnings + +import pytest +from dbt.tests.util import run_dbt, run_dbt_and_capture +from urllib3.exceptions import InsecureRequestWarning + + +class TestRequireCertificateValidationDefault: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {}} + + def test_require_certificate_validation_logs(self, project): + dbt_args = ["show", "--inline", "select 1"] + _, logs = run_dbt_and_capture(dbt_args) + assert "It is strongly advised to enable `require_certificate_validation` flag" in logs + + @pytest.mark.skip_profile("trino_starburst") + def test_require_certificate_validation_insecure_request_warning(self, project): + with warnings.catch_warnings(record=True) as w: + dbt_args = ["show", "--inline", "select 1"] + run_dbt(dbt_args) + + # Check if any InsecureRequestWarning was raised + assert any( + issubclass(warning.category, InsecureRequestWarning) for warning in w + ), "InsecureRequestWarning was not raised" + + +class TestRequireCertificateValidationFalse: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {"require_certificate_validation": False}} + + def test_require_certificate_validation_logs(self, project): + dbt_args = ["show", "--inline", "select 1"] + _, logs = run_dbt_and_capture(dbt_args) + assert "It is strongly advised to enable `require_certificate_validation` flag" in logs + + @pytest.mark.skip_profile("trino_starburst") + def test_require_certificate_validation_insecure_request_warning(self, project): + with warnings.catch_warnings(record=True) as w: + dbt_args = ["show", "--inline", "select 1"] + run_dbt(dbt_args) + + # Check if any InsecureRequestWarning was raised + assert any( + issubclass(warning.category, InsecureRequestWarning) for warning in w + ), "InsecureRequestWarning was not raised" + + +class TestRequireCertificateValidationTrue: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {"require_certificate_validation": True}} + + def test_require_certificate_validation_logs(self, project): + dbt_args = ["show", "--inline", "select 1"] + _, logs = run_dbt_and_capture(dbt_args) + assert "It is strongly advised to enable `require_certificate_validation` flag" not in logs + + @pytest.mark.skip_profile("trino_starburst") + def test_require_certificate_validation_insecure_request_warning(self, project): + with warnings.catch_warnings(record=True) as w: + dbt_args = ["show", "--inline", "select 1"] + run_dbt(dbt_args) + + # Check if not any InsecureRequestWarning was raised + assert not any( + issubclass(warning.category, InsecureRequestWarning) for warning in w + ), "InsecureRequestWarning was not raised"