Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve run times when using key pair auth by caching the private key #1110

Merged
merged 10 commits into from
Jul 11, 2024
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240710-172345.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Changelog entry for the private-key caching feature.
kind: Features
body: Improve run times when using key pair auth by caching the private key
time: 2024-07-10T17:23:45.046905-04:00
custom:
  # Author and Issue are fields of `custom`, so they must be nested under it.
  Author: mikealfare aranke
  Issue: "1082"
98 changes: 66 additions & 32 deletions dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import base64
import datetime
import os
import sys

if sys.version_info < (3, 9):
from functools import lru_cache

cache = lru_cache(maxsize=None)
else:
from functools import cache

import pytz
import re
Expand All @@ -13,6 +21,7 @@

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
import requests
import snowflake.connector
import snowflake.connector.constants
Expand Down Expand Up @@ -63,6 +72,58 @@
}


@cache
Copy link
Contributor

@colin-rogers-dbt colin-rogers-dbt Jul 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we move these to their own module? Just want to keep connections.py from getting bigger if we don't need to

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually want to move them into dbt-adapters long term. I thought about creating a separate module, but thought it would be unnecessary if we were going to move it later. I can add that here though if you think it's worth it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's worth it as it should make it cleaner to move later

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved over only the two methods that are generic. I left the one that's Snowflake-specific (converting the key to DER format).

Along the same lines of making connections.py smaller, I actually think the Credentials class should be in its own module (probably named `auth`, even though that is the name of the module I just moved these methods into). That's a separate PR, but it shouldn't take long if we want to do it.

def private_key_from_string(
    private_key_string: str, passphrase: Optional[str] = None
) -> RSAPrivateKey:
    """Load a private key from its string representation.

    A string that starts with "-" is parsed as PEM text; any other string is
    Base64-decoded and parsed as DER. A truthy ``passphrase`` is encoded to
    bytes and supplied as the key password; otherwise no password is used.
    """
    password = passphrase.encode() if passphrase else None

    if not private_key_string.startswith("-"):
        # No PEM header marker, so treat the input as Base64-encoded DER.
        der_bytes = base64.b64decode(private_key_string)
        return serialization.load_der_private_key(
            data=der_bytes,
            password=password,
            backend=default_backend(),
        )

    pem_bytes = bytes(private_key_string, "utf-8")
    return serialization.load_pem_private_key(
        data=pem_bytes,
        password=password,
        backend=default_backend(),
    )


@cache
def private_key_from_file(
    private_key_path: str, passphrase: Optional[str] = None
) -> RSAPrivateKey:
    """Load a PEM private key from the file at ``private_key_path``.

    The file is read in binary mode. A truthy ``passphrase`` is encoded to
    bytes and supplied as the key password; otherwise no password is used.
    Because the function is wrapped in ``cache``, repeated calls with the same
    arguments reuse the first result instead of reading the file again.
    """
    password = passphrase.encode() if passphrase else None

    with open(private_key_path, "rb") as key_file:
        pem_bytes = key_file.read()

    return serialization.load_pem_private_key(
        data=pem_bytes, password=password, backend=default_backend()
    )


@cache
def snowflake_private_key(private_key: RSAPrivateKey) -> bytes:
    """Serialize ``private_key`` as unencrypted PKCS8 bytes in DER encoding."""
    der_encoding = serialization.Encoding.DER
    pkcs8_format = serialization.PrivateFormat.PKCS8
    return private_key.private_bytes(
        encoding=der_encoding,
        format=pkcs8_format,
        encryption_algorithm=serialization.NoEncryption(),
    )


@dataclass
class SnowflakeAdapterResponse(AdapterResponse):
query_id: str = ""
Expand Down Expand Up @@ -273,44 +334,17 @@ def _get_access_token(self) -> str:
)
return result_json["access_token"]

def _get_private_key(self) -> Optional[bytes]:
    """Get Snowflake private key by path, from a Base64 encoded DER bytestring or None.

    Returns:
        The key as unencrypted DER/PKCS8 bytes, or ``None`` when neither
        ``private_key`` nor ``private_key_path`` is set.

    Raises:
        DbtConfigError: if both ``private_key`` and ``private_key_path`` are set.
    """
    if self.private_key and self.private_key_path:
        raise DbtConfigError("Cannot specify both `private_key` and `private_key_path`")

    # Loading is delegated to the module-level helpers so their caching applies.
    if self.private_key:
        private_key = private_key_from_string(self.private_key, self.private_key_passphrase)
    elif self.private_key_path:
        private_key = private_key_from_file(self.private_key_path, self.private_key_passphrase)
    else:
        return None

    return snowflake_private_key(private_key)


class SnowflakeConnectionManager(SQLConnectionManager):
Expand Down
27 changes: 27 additions & 0 deletions tests/functional/auth_tests/test_key_pair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os

from dbt.tests.util import run_dbt
import pytest


class TestKeyPairAuth:
    """Functional test that runs dbt against Snowflake using key pair credentials."""

    @pytest.fixture(scope="class", autouse=True)
    def dbt_profile_target(self):
        """Build the profile target from the SNOWFLAKE_TEST_* environment variables."""
        # No "authenticator" entry here: "oauth" is for token-based login, and this
        # profile supplies only a private key and passphrase, no token.
        return {
            "type": "snowflake",
            "threads": 4,
            "account": os.getenv("SNOWFLAKE_TEST_ACCOUNT"),
            "user": os.getenv("SNOWFLAKE_TEST_USER"),
            "private_key": os.getenv("SNOWFLAKE_TEST_PRIVATE_KEY"),
            "private_key_passphrase": os.getenv("SNOWFLAKE_TEST_PRIVATE_KEY_PASSPHRASE"),
            "database": os.getenv("SNOWFLAKE_TEST_DATABASE"),
            "warehouse": os.getenv("SNOWFLAKE_TEST_WAREHOUSE"),
        }

    @pytest.fixture(scope="class")
    def models(self):
        """Provide a single trivial model for the run."""
        return {"model.sql": "select 1 as id"}

    def test_snowflake_basic(self, project):
        """Run dbt once with the key pair profile."""
        run_dbt()
61 changes: 61 additions & 0 deletions tests/unit/test_private_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import tempfile
from typing import Generator

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
import pytest

from dbt.adapters.snowflake.connections import private_key_from_file, private_key_from_string


PASSPHRASE = "password1234"


def serialize(private_key: rsa.RSAPrivateKey) -> bytes:
    """Return ``private_key`` as unencrypted DER/PKCS8 bytes so keys can be compared by value."""
    return private_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )


@pytest.fixture(scope="session")
def private_key() -> rsa.RSAPrivateKey:
    """Generate one 2048-bit RSA key (public exponent 65537) for the whole test session."""
    generated_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    return generated_key


@pytest.fixture(scope="session")
def private_key_string(private_key) -> str:
    """Return the session key as PEM/PKCS8 text encrypted with ``PASSPHRASE``."""
    encryption = serialization.BestAvailableEncryption(PASSPHRASE.encode())
    pem_bytes = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=encryption,
    )
    return pem_bytes.decode("utf-8")


@pytest.fixture(scope="session")
def private_key_file(private_key) -> Generator[str, None, None]:
    """Write the session key to a temporary PEM file and yield the file's path.

    The key is serialized as PEM/PKCS8 encrypted with ``PASSPHRASE``. The
    temporary file is closed when the fixture is torn down, and also if
    writing to it fails.
    """
    private_key_bytes = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.BestAvailableEncryption(PASSPHRASE.encode()),
    )
    with tempfile.NamedTemporaryFile() as file:
        file.write(private_key_bytes)
        # Flush so the bytes are on disk before the code under test reopens the file by name.
        file.flush()
        yield file.name


def test_private_key_from_string_pem(private_key_string, private_key):
    """A passphrase-protected PEM string loads back to the original key."""
    assert isinstance(private_key_string, str)
    loaded_key = private_key_from_string(private_key_string, PASSPHRASE)
    expected_bytes = serialize(private_key)
    assert serialize(loaded_key) == expected_bytes


def test_private_key_from_file(private_key_file, private_key):
    """A passphrase-protected PEM file loads back to the original key."""
    assert os.path.exists(private_key_file)
    loaded_key = private_key_from_file(private_key_file, PASSPHRASE)
    expected_bytes = serialize(private_key)
    assert serialize(loaded_key) == expected_bytes
Loading