diff --git a/.changes/unreleased/Features-20240709-194316.yaml b/.changes/unreleased/Features-20240709-194316.yaml new file mode 100644 index 000000000..a867387e3 --- /dev/null +++ b/.changes/unreleased/Features-20240709-194316.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Improve run times for large projects by reusing connections by default +time: 2024-07-09T19:43:16.489649-04:00 +custom: + Author: mikealfare amardatar + Issue: "1082" diff --git a/.changes/unreleased/Features-20240710-172345.yaml b/.changes/unreleased/Features-20240710-172345.yaml new file mode 100644 index 000000000..e68f63812 --- /dev/null +++ b/.changes/unreleased/Features-20240710-172345.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Improve run times when using key pair auth by caching the private key +time: 2024-07-10T17:23:45.046905-04:00 +custom: + Author: mikealfare aranke + Issue: "1082" diff --git a/.changes/unreleased/Fixes-20240705-165932.yaml b/.changes/unreleased/Fixes-20240705-165932.yaml new file mode 100644 index 000000000..ffe902c92 --- /dev/null +++ b/.changes/unreleased/Fixes-20240705-165932.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Use show ... starts with instead of show ... 
def _passphrase_bytes(passphrase: Optional[str]) -> Optional[bytes]:
    """Return ``passphrase`` encoded to bytes, or ``None`` when it is empty or unset."""
    return passphrase.encode() if passphrase else None


@cache
def private_key_from_string(
    private_key_string: str, passphrase: Optional[str] = None
) -> RSAPrivateKey:
    """Load a private key that was supplied inline as text.

    A value starting with "-" is parsed as PEM text; any other value is
    base64-decoded and parsed as DER. The function is wrapped in ``cache``,
    so repeated calls with the same arguments reuse the first result.

    Args:
        private_key_string: PEM text, or base64-encoded DER.
        passphrase: Passphrase protecting the key; falsy values mean "no passphrase".

    Returns:
        The key object produced by ``cryptography``'s serialization loaders.
    """
    password = _passphrase_bytes(passphrase)

    looks_like_pem = private_key_string.startswith("-")
    if not looks_like_pem:
        der_bytes = base64.b64decode(private_key_string)
        return serialization.load_der_private_key(
            data=der_bytes,
            password=password,
            backend=default_backend(),
        )

    pem_bytes = private_key_string.encode("utf-8")
    return serialization.load_pem_private_key(
        data=pem_bytes,
        password=password,
        backend=default_backend(),
    )


@cache
def private_key_from_file(
    private_key_path: str, passphrase: Optional[str] = None
) -> RSAPrivateKey:
    """Load a PEM private key from the file at ``private_key_path``.

    The function is wrapped in ``cache``, so the file is only read the first
    time a given (path, passphrase) pair is requested.

    Args:
        private_key_path: Path of the PEM file; it is opened in binary mode.
        passphrase: Passphrase protecting the key; falsy values mean "no passphrase".

    Returns:
        The key object produced by ``serialization.load_pem_private_key``.
    """
    password = _passphrase_bytes(passphrase)

    with open(private_key_path, "rb") as key_file:
        pem_bytes = key_file.read()

    return serialization.load_pem_private_key(
        data=pem_bytes,
        password=password,
        backend=default_backend(),
    )
str = "" @@ -96,6 +115,7 @@ class SnowflakeCredentials(Credentials): retry_on_database_errors: bool = False retry_all: bool = False insecure_mode: Optional[bool] = False + # this needs to default to `None` so that we can tell if the user set it; see `__post_init__()` reuse_connections: Optional[bool] = None def __post_init__(self): @@ -126,6 +146,11 @@ def __post_init__(self): self.account = self.account.replace("_", "-") + # only default `reuse_connections` to `True` if the user has not turned on `client_session_keep_alive` + # having both of these set to `True` could lead to hanging open connections, so it should be opt-in behavior + if self.client_session_keep_alive is False and self.reuse_connections is None: + self.reuse_connections = True + @property def type(self): return "snowflake" @@ -275,44 +300,17 @@ def _get_access_token(self) -> str: ) return result_json["access_token"] - def _get_private_key(self): + def _get_private_key(self) -> Optional[bytes]: """Get Snowflake private key by path, from a Base64 encoded DER bytestring or None.""" if self.private_key and self.private_key_path: raise DbtConfigError("Cannot specify both `private_key` and `private_key_path`") - - if self.private_key_passphrase: - encoded_passphrase = self.private_key_passphrase.encode() - else: - encoded_passphrase = None - - if self.private_key: - if self.private_key.startswith("-"): - p_key = serialization.load_pem_private_key( - data=bytes(self.private_key, "utf-8"), - password=encoded_passphrase, - backend=default_backend(), - ) - - else: - p_key = serialization.load_der_private_key( - data=base64.b64decode(self.private_key), - password=encoded_passphrase, - backend=default_backend(), - ) - + elif self.private_key: + private_key = private_key_from_string(self.private_key, self.private_key_passphrase) elif self.private_key_path: - with open(self.private_key_path, "rb") as key: - p_key = serialization.load_pem_private_key( - key.read(), password=encoded_passphrase, 
backend=default_backend() - ) + private_key = private_key_from_file(self.private_key_path, self.private_key_passphrase) else: return None - - return p_key.private_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) + return snowflake_private_key(private_key) class SnowflakeConnectionManager(SQLConnectionManager): diff --git a/dbt/adapters/snowflake/impl.py b/dbt/adapters/snowflake/impl.py index 092510e8a..6854b199d 100644 --- a/dbt/adapters/snowflake/impl.py +++ b/dbt/adapters/snowflake/impl.py @@ -156,7 +156,7 @@ def _show_object_metadata(self, relation: SnowflakeRelation) -> Optional[dict]: def get_catalog_for_single_relation( self, relation: SnowflakeRelation ) -> Optional[CatalogTable]: - object_metadata = self._show_object_metadata(relation) + object_metadata = self._show_object_metadata(relation.as_case_sensitive()) if not object_metadata: return None diff --git a/dbt/adapters/snowflake/relation.py b/dbt/adapters/snowflake/relation.py index f477265f0..ace85695b 100644 --- a/dbt/adapters/snowflake/relation.py +++ b/dbt/adapters/snowflake/relation.py @@ -2,12 +2,12 @@ from typing import FrozenSet, Optional, Type from dbt.adapters.base.relation import BaseRelation +from dbt.adapters.contracts.relation import ComponentName, RelationConfig from dbt.adapters.relation_configs import ( RelationConfigBase, RelationConfigChangeAction, RelationResults, ) -from dbt.adapters.contracts.relation import RelationConfig from dbt.adapters.utils import classproperty from dbt_common.exceptions import DbtRuntimeError @@ -106,3 +106,17 @@ def dynamic_table_config_changeset( if config_change_collection.has_changes: return config_change_collection return None + + def as_case_sensitive(self) -> "SnowflakeRelation": + path_part_map = {} + + for path in ComponentName: + if self.include_policy.get_part(path): + part = self.path.get_part(path) + if part: + if self.quote_policy.get_part(path): + 
class TestKeyPairAuth:
    """Functional check that a dbt run succeeds when the profile uses key pair auth."""

    @pytest.fixture(scope="class", autouse=True)
    def dbt_profile_target(self):
        """Build the Snowflake profile target from the test environment variables."""
        # Profile field -> environment variable that supplies its value.
        env_backed_fields = {
            "account": "SNOWFLAKE_TEST_ACCOUNT",
            "user": "SNOWFLAKE_TEST_USER",
            "private_key": "SNOWFLAKE_TEST_PRIVATE_KEY",
            "private_key_passphrase": "SNOWFLAKE_TEST_PRIVATE_KEY_PASSPHRASE",
            "database": "SNOWFLAKE_TEST_DATABASE",
            "warehouse": "SNOWFLAKE_TEST_WAREHOUSE",
        }
        target = {"type": "snowflake", "threads": 4}
        for field_name, env_var in env_backed_fields.items():
            target[field_name] = os.getenv(env_var)
        return target

    @pytest.fixture(scope="class")
    def models(self):
        """Provide a single trivial model so the run has something to build."""
        project_models = {"my_model.sql": "select 1 as id"}
        return project_models

    def test_connection(self, project):
        """Run dbt with its default arguments against the key pair profile."""
        run_dbt()
tests/functional/oauth/test_oauth.py rename to tests/functional/auth_tests/test_oauth.py diff --git a/tests/performance/README.md b/tests/performance/README.md new file mode 100644 index 000000000..02130c5c6 --- /dev/null +++ b/tests/performance/README.md @@ -0,0 +1,6 @@ +# Performance testing + +These tests are not meant to run on a regular basis; instead, they are tools for measuring performance impacts of changes as needed. +We often get requests for reducing processing times, researching why a particular component is taking longer to run than expected, etc. +In the past we have performed one-off analyses to address these requests and documented the results in the relevant PR (when a change is made). +It is more useful to document those analyses in the form of performance tests so that we can easily rerun the analysis at a later date. diff --git a/tests/performance/test_auth_methods.py b/tests/performance/test_auth_methods.py new file mode 100644 index 000000000..ad0b424ab --- /dev/null +++ b/tests/performance/test_auth_methods.py @@ -0,0 +1,132 @@ +""" +Results: + +| method | project_size | reuse_connections | unsafe_skip_rsa_key_validation | duration | +|---------------|--------------|-------------------|--------------------------------|----------| +| User Password | 1,000 | False | - | 234.09s | +| User Password | 1,000 | True | - | 78.34s | +| Key Pair | 1,000 | False | False | 271.47s | +| Key Pair | 1,000 | False | True | 275.73s | +| Key Pair | 1,000 | True | False | 63.69s | +| Key Pair | 1,000 | True | True | 73.45s | + +Notes: +- run locally on MacOS, single threaded +- `unsafe_skip_rsa_key_validation` only applies to the Key Pair auth method +- `unsafe_skip_rsa_key_validation=True` was tested by updating the relevant `cryptography` calls directly as it is not a user configuration +- since the models are all views, time differences should be viewed as absolute differences, e.g.: + - this: (271.47s - 63.69s) / 1,000 models = 208ms improvement + - NOT 
class Scenario:
    """
    Base class for a load-test scenario that runs a project of generated models.

    Subclasses configure a scenario by setting:
    - `auth_method`: either `user_password` or `key_pair`
    - `project_size`: the number of models generated for the project
    - `reuse_connections`: the value passed through to the profile's `reuse_connections`

    The class-scoped `setup` fixture seeds the project, then prints the elapsed
    wall-clock time once the tests in the class have finished.
    """

    auth_method: str
    project_size: int = 1
    reuse_connections: bool = False

    @pytest.fixture(scope="class")
    def seeds(self):
        """Provide the single seed that every generated model selects from."""
        return {"my_seed.csv": SEED}

    @pytest.fixture(scope="class")
    def models(self):
        """Generate `project_size` identical models named `my_model_<i>.sql`."""
        return {f"my_model_{i}.sql": MODEL for i in range(self.project_size)}

    @pytest.fixture(scope="class", autouse=True)
    def setup(self, project):
        """Seed the project, then time everything that runs until fixture teardown."""
        # Seeding happens before the clock starts so it is excluded from the duration.
        run_dbt(["seed"])

        start = datetime.now()
        yield
        end = datetime.now()

        duration = (end - start).total_seconds()
        print(f"Run took: {duration} seconds")

    @pytest.fixture(scope="class")
    def dbt_profile_target(self, auth_params):
        """Build the profile target, merging in the auth-specific parameters."""
        yield {
            "type": "snowflake",
            "threads": 4,
            "account": os.getenv("SNOWFLAKE_TEST_ACCOUNT"),
            "database": os.getenv("SNOWFLAKE_TEST_DATABASE"),
            "warehouse": os.getenv("SNOWFLAKE_TEST_WAREHOUSE"),
            "user": os.getenv("SNOWFLAKE_TEST_USER"),
            "reuse_connections": self.reuse_connections,
            **auth_params,
        }

    @pytest.fixture(scope="class")
    def auth_params(self):
        """
        Yield the profile fields specific to `auth_method`.

        Raises:
            ValueError: If `auth_method` is neither `user_password` nor `key_pair`.
        """
        if self.auth_method == "user_password":
            yield {"password": os.getenv("SNOWFLAKE_TEST_PASSWORD")}

        elif self.auth_method == "key_pair":
            # This connection method uses key pair auth.
            # Follow the instructions here to setup key pair authentication for your test user:
            # https://docs.snowflake.com/en/user-guide/key-pair-auth
            yield {
                "private_key": os.getenv("SNOWFLAKE_TEST_PRIVATE_KEY"),
                "private_key_passphrase": os.getenv("SNOWFLAKE_TEST_PRIVATE_KEY_PASSPHRASE"),
            }

        else:
            raise ValueError(
                f"`auth_method` must be one of `user_password` or `key_pair`, received: {self.auth_method}"
            )

    def test_scenario(self, project):
        """Run the whole generated project; the `setup` fixture reports the duration."""
        run_dbt(["run"])
@pytest.fixture(scope="session")
def private_key_file(private_key) -> Generator[str, None, None]:
    """Write the session's private key to a temporary PEM file and yield its path.

    The key is serialized as PKCS8 PEM, encrypted with `PASSPHRASE`. The file
    exists only while the fixture is active.

    Args:
        private_key: The key object provided by the `private_key` fixture.

    Yields:
        The filesystem path of the temporary PEM file.
    """
    private_key_bytes = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.BestAvailableEncryption(PASSPHRASE.encode()),
    )
    # The context manager closes the temporary file on every exit path,
    # including a failed write; the original only closed it after a clean yield.
    with tempfile.NamedTemporaryFile() as file:
        file.write(private_key_bytes)
        # Flush explicitly so a reader opening the file by name sees the whole key.
        file.flush()
        yield file.name
case_sensitive_relation.schema == "My_ScHeMa" + assert case_sensitive_relation.identifier == "MY_TABLE" + assert case_sensitive_relation.render() == 'MY_DB."My_ScHeMa".MY_TABLE' diff --git a/tests/unit/test_snowflake_adapter.py b/tests/unit/test_snowflake_adapter.py index f6a768da8..32e73eb45 100644 --- a/tests/unit/test_snowflake_adapter.py +++ b/tests/unit/test_snowflake_adapter.py @@ -290,13 +290,19 @@ def test_client_session_keep_alive_false_by_default(self): application="dbt", insecure_mode=False, session_parameters={}, - reuse_connections=None, + reuse_connections=True, ), ] ) def test_client_session_keep_alive_true(self): - self.config.credentials = self.config.credentials.replace(client_session_keep_alive=True) + self.config.credentials = self.config.credentials.replace( + client_session_keep_alive=True, + # this gets defaulted via `__post_init__` when `client_session_keep_alive` comes in as `False` + # then when `replace` is called, `__post_init__` cannot set it back to `None` since it cannot + # tell the difference between set by user and set by `__post_init__` + reuse_connections=None, + ) self.adapter = SnowflakeAdapter(self.config, get_context("spawn")) conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config") @@ -339,7 +345,7 @@ def test_client_has_query_tag(self): role=None, schema="public", user="test_user", - reuse_connections=None, + reuse_connections=True, warehouse="test_warehouse", private_key=None, application="dbt", @@ -379,7 +385,7 @@ def test_user_pass_authentication(self): application="dbt", insecure_mode=False, session_parameters={}, - reuse_connections=None, + reuse_connections=True, ) ] ) @@ -413,7 +419,7 @@ def test_authenticator_user_pass_authentication(self): client_store_temporary_credential=True, insecure_mode=False, session_parameters={}, - reuse_connections=None, + reuse_connections=True, ) ] ) @@ -443,7 +449,7 @@ def test_authenticator_externalbrowser_authentication(self): 
client_store_temporary_credential=True, insecure_mode=False, session_parameters={}, - reuse_connections=None, + reuse_connections=True, ) ] ) @@ -477,7 +483,7 @@ def test_authenticator_oauth_authentication(self): client_store_temporary_credential=True, insecure_mode=False, session_parameters={}, - reuse_connections=None, + reuse_connections=True, ) ] ) @@ -511,7 +517,7 @@ def test_authenticator_private_key_authentication(self, mock_get_private_key): application="dbt", insecure_mode=False, session_parameters={}, - reuse_connections=None, + reuse_connections=True, ) ] ) @@ -545,7 +551,7 @@ def test_authenticator_private_key_authentication_no_passphrase(self, mock_get_p application="dbt", insecure_mode=False, session_parameters={}, - reuse_connections=None, + reuse_connections=True, ) ] ) @@ -577,7 +583,7 @@ def test_authenticator_jwt_authentication(self): client_store_temporary_credential=True, insecure_mode=False, session_parameters={}, - reuse_connections=None, + reuse_connections=True, ) ] ) @@ -607,7 +613,7 @@ def test_query_tag(self): application="dbt", insecure_mode=False, session_parameters={"QUERY_TAG": "test_query_tag"}, - reuse_connections=None, + reuse_connections=True, ) ] ) @@ -670,7 +676,7 @@ def test_authenticator_private_key_string_authentication(self, mock_get_private_ application="dbt", insecure_mode=False, session_parameters={}, - reuse_connections=None, + reuse_connections=True, ) ] ) @@ -706,7 +712,7 @@ def test_authenticator_private_key_string_authentication_no_passphrase( application="dbt", insecure_mode=False, session_parameters={}, - reuse_connections=None, + reuse_connections=True, ) ] )