diff --git a/.changes/unreleased/Under the Hood-20241120-191809.yaml b/.changes/unreleased/Under the Hood-20241120-191809.yaml new file mode 100644 index 00000000..9f92e281 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20241120-191809.yaml @@ -0,0 +1,7 @@ +kind: Under the Hood +body: Revert cert default to False. Add require_certificate_validation Behavior Flag +time: 2024-11-20T19:18:09.725288+01:00 +custom: + Author: damian3031 + Issue: "" + PR: "447" diff --git a/dbt/adapters/trino/connections.py b/dbt/adapters/trino/connections.py index f2d49ad9..a48b08d1 100644 --- a/dbt/adapters/trino/connections.py +++ b/dbt/adapters/trino/connections.py @@ -423,6 +423,12 @@ class TrinoAdapterResponse(AdapterResponse): class TrinoConnectionManager(SQLConnectionManager): TYPE = "trino" + behavior_flags = None + + def __init__(self, profile, mp_context, behavior_flags) -> None: + super().__init__(profile, mp_context) + + TrinoConnectionManager.behavior_flags = behavior_flags @contextmanager def exception_handler(self, sql): @@ -464,6 +470,11 @@ def open(cls, connection): return connection credentials = connection.credentials + verify = ( + credentials.cert + if credentials.cert is not None + else cls.behavior_flags.require_certificate_validation.setting + ) # it's impossible for trino to fail here as 'connections' are actually # just cursor factories. @@ -484,7 +495,7 @@ def open(cls, connection): max_attempts=credentials.retries, isolation_level=IsolationLevel.AUTOCOMMIT, source=f"dbt-trino-{version}", - verify=credentials.cert, + verify=verify, timezone=credentials.timezone, ) connection.state = "open" diff --git a/dbt/adapters/trino/impl.py b/dbt/adapters/trino/impl.py index d1d87c2a..b68bb834 100644 --- a/dbt/adapters/trino/impl.py +++ b/dbt/adapters/trino/impl.py @@ -3,6 +3,7 @@ import agate from dbt.adapters.base.impl import AdapterConfig, ConstraintSupport +from dbt.adapters.cache import RelationsCache from dbt.adapters.capability import ( Capability, CapabilityDict, @@ -10,6 +11,7 @@ Support, ) from dbt.adapters.sql import SQLAdapter +from dbt_common.behavior_flags import BehaviorFlag from dbt_common.contracts.constraints import ConstraintType from dbt_common.exceptions import DbtDatabaseError @@ -47,6 +49,31 @@ class TrinoAdapter(SQLAdapter): } ) + def __init__(self, config, mp_context) -> None: + self.config = config + self.cache = RelationsCache(log_cache_events=config.log_cache_events) + # this will be updated to include global behavior flags once they exist + self.behavior = [] # type: ignore + self.connections = self.ConnectionManager(config, mp_context, self.behavior) + self._macro_resolver = None + self._macro_context_generator = None + + @property + def _behavior_flags(self) -> list[BehaviorFlag]: + return [ + { + "name": "require_certificate_validation", + "default": False, + "description": ( + "SSL certificate validation is disabled by default. " + "It is legacy behavior which will be changed in future releases. " + "It is strongly advised to enable `require_certificate_validation` flag " + "or explicitly set `cert` configuration to `True` for security reasons. " + "You may receive an error after that if your SSL setup is incorrect." + ), + } + ] + @classmethod def date_function(cls): return "datenow()" diff --git a/tests/functional/adapter/behavior_flags/test_require_certificate_validation.py b/tests/functional/adapter/behavior_flags/test_require_certificate_validation.py new file mode 100644 index 00000000..ac089ebd --- /dev/null +++ b/tests/functional/adapter/behavior_flags/test_require_certificate_validation.py @@ -0,0 +1,71 @@ +import warnings + +import pytest +from dbt.tests.util import run_dbt, run_dbt_and_capture +from urllib3.exceptions import InsecureRequestWarning + + +class TestRequireCertificateValidationDefault: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {}} + + def test_require_certificate_validation_logs(self, project): + dbt_args = ["show", "--inline", "select 1"] + _, logs = run_dbt_and_capture(dbt_args) + assert "It is strongly advised to enable `require_certificate_validation` flag" in logs + + @pytest.mark.skip_profile("trino_starburst") + def test_require_certificate_validation_insecure_request_warning(self, project): + with warnings.catch_warnings(record=True) as w: + dbt_args = ["show", "--inline", "select 1"] + run_dbt(dbt_args) + + # Check if any InsecureRequestWarning was raised + assert any( + issubclass(warning.category, InsecureRequestWarning) for warning in w + ), "InsecureRequestWarning was not raised" + + +class TestRequireCertificateValidationFalse: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {"require_certificate_validation": False}} + + def test_require_certificate_validation_logs(self, project): + dbt_args = ["show", "--inline", "select 1"] + _, logs = run_dbt_and_capture(dbt_args) + assert "It is strongly advised to enable `require_certificate_validation` flag" in logs + + @pytest.mark.skip_profile("trino_starburst") + def test_require_certificate_validation_insecure_request_warning(self, project): + with warnings.catch_warnings(record=True) as w: + dbt_args = ["show", "--inline", "select 1"] + run_dbt(dbt_args) + + # Check if any InsecureRequestWarning was raised + assert any( + issubclass(warning.category, InsecureRequestWarning) for warning in w + ), "InsecureRequestWarning was not raised" + + +class TestRequireCertificateValidationTrue: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {"require_certificate_validation": True}} + + def test_require_certificate_validation_logs(self, project): + dbt_args = ["show", "--inline", "select 1"] + _, logs = run_dbt_and_capture(dbt_args) + assert "It is strongly advised to enable `require_certificate_validation` flag" not in logs + + @pytest.mark.skip_profile("trino_starburst") + def test_require_certificate_validation_insecure_request_warning(self, project): + with warnings.catch_warnings(record=True) as w: + dbt_args = ["show", "--inline", "select 1"] + run_dbt(dbt_args) + + # Check if not any InsecureRequestWarning was raised + assert not any( + issubclass(warning.category, InsecureRequestWarning) for warning in w + ), "InsecureRequestWarning was not raised"