From ca00b23ad0021b9cfd2626a49cb34a66ae2cb215 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Wed, 10 Jul 2024 17:19:23 -0400 Subject: [PATCH 1/8] cache _get_private_key --- dbt/adapters/snowflake/connections.py | 114 +++++++++++++++++--------- 1 file changed, 77 insertions(+), 37 deletions(-) diff --git a/dbt/adapters/snowflake/connections.py b/dbt/adapters/snowflake/connections.py index 6e9a5aaba..4db6b462a 100644 --- a/dbt/adapters/snowflake/connections.py +++ b/dbt/adapters/snowflake/connections.py @@ -1,6 +1,14 @@ import base64 import datetime import os +import sys + +if sys.version_info < (3, 9): + from functools import lru_cache + + cache = lru_cache(maxsize=None) +else: + from functools import cache import pytz import re @@ -13,6 +21,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes import requests import snowflake.connector import snowflake.connector.constants @@ -63,6 +72,73 @@ } +@cache +def _private_key( + private_key: Optional[str] = None, + private_key_path: Optional[str] = None, + private_key_passphrase: Optional[str] = None, +) -> Optional[bytes]: + """Get Snowflake private key by path, from a Base64 encoded DER bytestring or None.""" + + if private_key and private_key_path: + raise DbtConfigError("Cannot specify both `private_key` and `private_key_path`") + elif private_key: + p_key = _private_key_from_string(private_key, private_key_passphrase) + elif private_key_path: + p_key = _private_key_from_file(private_key_path, private_key_passphrase) + else: + return None + + return p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + +@cache +def _private_key_from_string( + private_key: str, private_key_passphrase: Optional[str] = None +) -> PrivateKeyTypes: + + if private_key_passphrase: + encoded_passphrase = private_key_passphrase.encode() + else: + encoded_passphrase = None + + if private_key.startswith("-"): + return serialization.load_pem_private_key( + data=bytes(private_key, "utf-8"), + password=encoded_passphrase, + backend=default_backend(), + ) + return serialization.load_der_private_key( + data=base64.b64decode(private_key), + password=encoded_passphrase, + backend=default_backend(), + ) + + +@cache +def _private_key_from_file( + private_key_path: str, private_key_passphrase: Optional[str] = None +) -> Optional[bytes]: + + if private_key_passphrase: + encoded_passphrase = private_key_passphrase.encode() + else: + encoded_passphrase = None + + with open(private_key_path, "rb") as key: + p_key_bytes = key.read() + + return serialization.load_pem_private_key( + data=p_key_bytes, + password=encoded_passphrase, + backend=default_backend(), + ) + + @dataclass class SnowflakeAdapterResponse(AdapterResponse): query_id: str = "" @@ -274,43 +350,7 @@ def _get_access_token(self) -> str: return result_json["access_token"] def _get_private_key(self): - """Get Snowflake private key by path, from a Base64 encoded DER bytestring or None.""" - if self.private_key and self.private_key_path: - raise DbtConfigError("Cannot specify both `private_key` and `private_key_path`") - - if self.private_key_passphrase: - encoded_passphrase = self.private_key_passphrase.encode() - else: - encoded_passphrase = None - - if self.private_key: - if self.private_key.startswith("-"): - p_key = serialization.load_pem_private_key( - data=bytes(self.private_key, "utf-8"), - password=encoded_passphrase, - backend=default_backend(), - ) - - else: - p_key = serialization.load_der_private_key( - data=base64.b64decode(self.private_key), - password=encoded_passphrase, - backend=default_backend(), - ) - - elif self.private_key_path: - with open(self.private_key_path, "rb") as key: - p_key = serialization.load_pem_private_key( - key.read(), password=encoded_passphrase, backend=default_backend() - ) - else: - return None - - return p_key.private_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) + return _private_key(self.private_key, self.private_key_path, self.private_key_passphrase) class SnowflakeConnectionManager(SQLConnectionManager): From bdc55d93becc2a92f949b37fbfad8725e4f4744d Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Wed, 10 Jul 2024 17:22:12 -0400 Subject: [PATCH 2/8] add type hints --- dbt/adapters/snowflake/connections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/adapters/snowflake/connections.py b/dbt/adapters/snowflake/connections.py index 4db6b462a..ad77633a6 100644 --- a/dbt/adapters/snowflake/connections.py +++ b/dbt/adapters/snowflake/connections.py @@ -349,7 +349,7 @@ def _get_access_token(self) -> str: ) return result_json["access_token"] - def _get_private_key(self): + def _get_private_key(self) -> Optional[bytes]: return _private_key(self.private_key, self.private_key_path, self.private_key_passphrase) From b2271bf4afee41dd683fb2d15ee9cc051eddafaa Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Wed, 10 Jul 2024 17:23:54 -0400 Subject: [PATCH 3/8] changelog --- .changes/unreleased/Features-20240710-172345.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changes/unreleased/Features-20240710-172345.yaml diff --git a/.changes/unreleased/Features-20240710-172345.yaml b/.changes/unreleased/Features-20240710-172345.yaml new file mode 100644 index 000000000..e68f63812 --- /dev/null +++ b/.changes/unreleased/Features-20240710-172345.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Improve run times when using key pair auth by caching the private key +time: 2024-07-10T17:23:45.046905-04:00 +custom: + Author: mikealfare aranke + Issue: "1082" From 0be34d6e9ad314b6cca33c09f859744d4032a2c5 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Wed, 10 Jul 2024 23:36:25 -0400 Subject: [PATCH 4/8] update caching functions and add unit tests --- dbt/adapters/snowflake/connections.py | 78 +++++++++++++-------------- tests/unit/test_private_keys.py | 61 +++++++++++++++++++++ 2 files changed, 97 insertions(+), 42 deletions(-) create mode 100644 tests/unit/test_private_keys.py diff --git a/dbt/adapters/snowflake/connections.py b/dbt/adapters/snowflake/connections.py index ad77633a6..a40e5d776 100644 --- a/dbt/adapters/snowflake/connections.py +++ b/dbt/adapters/snowflake/connections.py @@ -21,7 +21,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey import requests import snowflake.connector import snowflake.connector.constants @@ -73,72 +73,57 @@ @cache -def _private_key( - private_key: Optional[str] = None, - private_key_path: Optional[str] = None, - private_key_passphrase: Optional[str] = None, -) -> Optional[bytes]: - """Get Snowflake private key by path, from a Base64 encoded DER bytestring or None.""" - - if private_key and private_key_path: - raise DbtConfigError("Cannot specify both `private_key` and `private_key_path`") - elif private_key: - p_key = _private_key_from_string(private_key, private_key_passphrase) - elif private_key_path: - p_key = _private_key_from_file(private_key_path, private_key_passphrase) - else: - return None - - return p_key.private_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) - - -@cache -def _private_key_from_string( - private_key: str, private_key_passphrase: Optional[str] = None -) -> PrivateKeyTypes: +def private_key_from_string( + private_key_string: str, passphrase: Optional[str] = None +) -> RSAPrivateKey: - if private_key_passphrase: - encoded_passphrase = private_key_passphrase.encode() + if passphrase: + encoded_passphrase = passphrase.encode() else: encoded_passphrase = None - if private_key.startswith("-"): + if private_key_string.startswith("-"): return serialization.load_pem_private_key( - data=bytes(private_key, "utf-8"), + data=bytes(private_key_string, "utf-8"), password=encoded_passphrase, backend=default_backend(), ) return serialization.load_der_private_key( - data=base64.b64decode(private_key), + data=base64.b64decode(private_key_string), password=encoded_passphrase, backend=default_backend(), ) @cache -def _private_key_from_file( - private_key_path: str, private_key_passphrase: Optional[str] = None -) -> Optional[bytes]: +def private_key_from_file( + private_key_path: str, passphrase: Optional[str] = None +) -> RSAPrivateKey: - if private_key_passphrase: - encoded_passphrase = private_key_passphrase.encode() + if passphrase: + encoded_passphrase = passphrase.encode() else: encoded_passphrase = None - with open(private_key_path, "rb") as key: - p_key_bytes = key.read() + with open(private_key_path, "rb") as file: + private_key_bytes = file.read() return serialization.load_pem_private_key( - data=p_key_bytes, + data=private_key_bytes, password=encoded_passphrase, backend=default_backend(), ) +@cache +def snowflake_private_key(private_key: RSAPrivateKey) -> bytes: + return private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + @dataclass class SnowflakeAdapterResponse(AdapterResponse): query_id: str = "" @@ -350,7 +335,16 @@ def _get_access_token(self) -> str: return result_json["access_token"] def _get_private_key(self) -> Optional[bytes]: - return _private_key(self.private_key, self.private_key_path, self.private_key_passphrase) + """Get Snowflake private key by path, from a Base64 encoded DER bytestring or None.""" + if self.private_key and self.private_key_path: + raise DbtConfigError("Cannot specify both `private_key` and `private_key_path`") + elif self.private_key: + private_key = private_key_from_string(self.private_key, self.private_key_passphrase) + elif self.private_key_path: + private_key = private_key_from_file(self.private_key_path, self.private_key_passphrase) + else: + return None + return snowflake_private_key(private_key) class SnowflakeConnectionManager(SQLConnectionManager): diff --git a/tests/unit/test_private_keys.py b/tests/unit/test_private_keys.py new file mode 100644 index 000000000..8a5d4c30b --- /dev/null +++ b/tests/unit/test_private_keys.py @@ -0,0 +1,61 @@ +import os +import tempfile +from typing import Generator + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +import pytest + +from dbt.adapters.snowflake.connections import private_key_from_file, private_key_from_string + + +PASSPHRASE = "password1234" + + +def serialize(private_key: rsa.RSAPrivateKey) -> bytes: + return private_key.private_bytes( + serialization.Encoding.DER, + serialization.PrivateFormat.PKCS8, + serialization.NoEncryption(), + ) + + +@pytest.fixture(scope="session") +def private_key() -> rsa.RSAPrivateKey: + return rsa.generate_private_key(public_exponent=65537, key_size=2048) + + +@pytest.fixture(scope="session") +def private_key_string(private_key) -> str: + private_key_bytes = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.BestAvailableEncryption(PASSPHRASE.encode()), + ) + return private_key_bytes.decode("utf-8") + + +@pytest.fixture(scope="session") +def private_key_file(private_key) -> Generator[str, None, None]: + private_key_bytes = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.BestAvailableEncryption(PASSPHRASE.encode()), + ) + file = tempfile.NamedTemporaryFile() + file.write(private_key_bytes) + file.seek(0) + yield file.name + file.close() + + +def test_private_key_from_string_pem(private_key_string, private_key): + assert isinstance(private_key_string, str) + calculated_private_key = private_key_from_string(private_key_string, PASSPHRASE) + assert serialize(calculated_private_key) == serialize(private_key) + + +def test_private_key_from_file(private_key_file, private_key): + assert os.path.exists(private_key_file) + calculated_private_key = private_key_from_file(private_key_file, PASSPHRASE) + assert serialize(calculated_private_key) == serialize(private_key) From e67867901a29ea54b1519514996818c781e8fd5f Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Wed, 10 Jul 2024 23:36:38 -0400 Subject: [PATCH 5/8] add an integration test for key pair auth method --- tests/functional/oauth/test_key_pair.py | 27 +++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 tests/functional/oauth/test_key_pair.py diff --git a/tests/functional/oauth/test_key_pair.py b/tests/functional/oauth/test_key_pair.py new file mode 100644 index 000000000..1d218c543 --- /dev/null +++ b/tests/functional/oauth/test_key_pair.py @@ -0,0 +1,27 @@ +import os + +from dbt.tests.util import run_dbt +import pytest + + +class TestKeyPairAuth: + @pytest.fixture(scope="class", autouse=True) + def dbt_profile_target(self): + return { + "type": "snowflake", + "threads": 4, + "account": os.getenv("SNOWFLAKE_TEST_ACCOUNT"), + "user": os.getenv("SNOWFLAKE_TEST_USER"), + "private_key": os.getenv("SNOWFLAKE_TEST_PRIVATE_KEY"), + "private_key_passphrase": os.getenv("SNOWFLAKE_TEST_PRIVATE_KEY_PASSPHRASE"), + "database": os.getenv("SNOWFLAKE_TEST_DATABASE"), + "warehouse": os.getenv("SNOWFLAKE_TEST_WAREHOUSE"), + "authenticator": "oauth", + } + + @pytest.fixture(scope="class") + def models(self): + return {"model.sql": "select 1 as id"} + + def test_snowflake_basic(self, project): + run_dbt() From 0ea9564cc817b6cffa67b7577dd47bbe5530cdcf Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Wed, 10 Jul 2024 23:38:29 -0400 Subject: [PATCH 6/8] generalize oauth test suite to auth_tests --- tests/functional/{oauth => auth_tests}/test_jwt.py | 0 tests/functional/{oauth => auth_tests}/test_key_pair.py | 0 tests/functional/{oauth => auth_tests}/test_oauth.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename tests/functional/{oauth => auth_tests}/test_jwt.py (100%) rename tests/functional/{oauth => auth_tests}/test_key_pair.py (100%) rename tests/functional/{oauth => auth_tests}/test_oauth.py (100%) diff --git a/tests/functional/oauth/test_jwt.py b/tests/functional/auth_tests/test_jwt.py similarity index 100% rename from tests/functional/oauth/test_jwt.py rename to tests/functional/auth_tests/test_jwt.py diff --git a/tests/functional/oauth/test_key_pair.py b/tests/functional/auth_tests/test_key_pair.py similarity index 100% rename from tests/functional/oauth/test_key_pair.py rename to tests/functional/auth_tests/test_key_pair.py diff --git a/tests/functional/oauth/test_oauth.py b/tests/functional/auth_tests/test_oauth.py similarity index 100% rename from tests/functional/oauth/test_oauth.py rename to tests/functional/auth_tests/test_oauth.py From 1440cc8397f5135e1bb642c3d584b16603d25451 Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Wed, 10 Jul 2024 23:53:12 -0400 Subject: [PATCH 7/8] typos --- tests/functional/auth_tests/test_key_pair.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/functional/auth_tests/test_key_pair.py b/tests/functional/auth_tests/test_key_pair.py index 1d218c543..6d3254f33 100644 --- a/tests/functional/auth_tests/test_key_pair.py +++ b/tests/functional/auth_tests/test_key_pair.py @@ -16,12 +16,11 @@ def dbt_profile_target(self): "private_key_passphrase": os.getenv("SNOWFLAKE_TEST_PRIVATE_KEY_PASSPHRASE"), "database": os.getenv("SNOWFLAKE_TEST_DATABASE"), "warehouse": os.getenv("SNOWFLAKE_TEST_WAREHOUSE"), - "authenticator": "oauth", } @pytest.fixture(scope="class") def models(self): - return {"model.sql": "select 1 as id"} + return {"my_model.sql": "select 1 as id"} - def test_snowflake_basic(self, project): + def test_connection(self, project): run_dbt() From 0c1163be81501e6a23f627ce8bae5b1f204e55fd Mon Sep 17 00:00:00 2001 From: Mike Alfare Date: Thu, 11 Jul 2024 16:14:59 -0400 Subject: [PATCH 8/8] move generic private key methods into their own module --- dbt/adapters/snowflake/auth.py | 57 +++++++++++++++++++++++++++ dbt/adapters/snowflake/connections.py | 46 +-------------------- tests/unit/test_private_keys.py | 2 +- 3 files changed, 60 insertions(+), 45 deletions(-) create mode 100644 dbt/adapters/snowflake/auth.py diff --git a/dbt/adapters/snowflake/auth.py b/dbt/adapters/snowflake/auth.py new file mode 100644 index 000000000..e914b6f3d --- /dev/null +++ b/dbt/adapters/snowflake/auth.py @@ -0,0 +1,57 @@ +import base64 +import sys +from typing import Optional + +if sys.version_info < (3, 9): + from functools import lru_cache + + cache = lru_cache(maxsize=None) +else: + from functools import cache + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey + + +@cache +def private_key_from_string( + private_key_string: str, passphrase: Optional[str] = None +) -> RSAPrivateKey: + + if passphrase: + encoded_passphrase = passphrase.encode() + else: + encoded_passphrase = None + + if private_key_string.startswith("-"): + return serialization.load_pem_private_key( + data=bytes(private_key_string, "utf-8"), + password=encoded_passphrase, + backend=default_backend(), + ) + return serialization.load_der_private_key( + data=base64.b64decode(private_key_string), + password=encoded_passphrase, + backend=default_backend(), + ) + + +@cache +def private_key_from_file( + private_key_path: str, passphrase: Optional[str] = None +) -> RSAPrivateKey: + + if passphrase: + encoded_passphrase = passphrase.encode() + else: + encoded_passphrase = None + + with open(private_key_path, "rb") as file: + private_key_bytes = file.read() + + return serialization.load_pem_private_key( + data=private_key_bytes, + password=encoded_passphrase, + backend=default_backend(), + ) diff --git a/dbt/adapters/snowflake/connections.py b/dbt/adapters/snowflake/connections.py index a40e5d776..a9e83cd2e 100644 --- a/dbt/adapters/snowflake/connections.py +++ b/dbt/adapters/snowflake/connections.py @@ -19,7 +19,6 @@ from typing import Optional, Tuple, Union, Any, List, Iterable, TYPE_CHECKING -from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey import requests @@ -53,6 +52,8 @@ from dbt.adapters.events.types import AdapterEventWarning, AdapterEventError from dbt_common.ui import line_wrap_message, warning_tag +from dbt.adapters.snowflake.auth import private_key_from_file, private_key_from_string + if TYPE_CHECKING: import agate @@ -72,49 +73,6 @@ } -@cache -def private_key_from_string( - private_key_string: str, passphrase: Optional[str] = None -) -> RSAPrivateKey: - - if passphrase: - encoded_passphrase = passphrase.encode() - else: - encoded_passphrase = None - - if private_key_string.startswith("-"): - return serialization.load_pem_private_key( - data=bytes(private_key_string, "utf-8"), - password=encoded_passphrase, - backend=default_backend(), - ) - return serialization.load_der_private_key( - data=base64.b64decode(private_key_string), - password=encoded_passphrase, - backend=default_backend(), - ) - - -@cache -def private_key_from_file( - private_key_path: str, passphrase: Optional[str] = None -) -> RSAPrivateKey: - - if passphrase: - encoded_passphrase = passphrase.encode() - else: - encoded_passphrase = None - - with open(private_key_path, "rb") as file: - private_key_bytes = file.read() - - return serialization.load_pem_private_key( - data=private_key_bytes, - password=encoded_passphrase, - backend=default_backend(), - ) - - @cache def snowflake_private_key(private_key: RSAPrivateKey) -> bytes: return private_key.private_bytes( diff --git a/tests/unit/test_private_keys.py b/tests/unit/test_private_keys.py index 8a5d4c30b..59b8522d2 100644 --- a/tests/unit/test_private_keys.py +++ b/tests/unit/test_private_keys.py @@ -6,7 +6,7 @@ from cryptography.hazmat.primitives.asymmetric import rsa import pytest -from dbt.adapters.snowflake.connections import private_key_from_file, private_key_from_string +from dbt.adapters.snowflake.auth import private_key_from_file, private_key_from_string PASSPHRASE = "password1234"