diff --git a/dbt/adapters/synapse/__version__.py b/dbt/adapters/synapse/__version__.py index 6aaa73b..72126ce 100644 --- a/dbt/adapters/synapse/__version__.py +++ b/dbt/adapters/synapse/__version__.py @@ -1 +1 @@ -version = "1.8.0" +version = "1.8.1" diff --git a/dbt/adapters/synapse/synapse_connection_manager.py b/dbt/adapters/synapse/synapse_connection_manager.py index 23103d1..5f17289 100644 --- a/dbt/adapters/synapse/synapse_connection_manager.py +++ b/dbt/adapters/synapse/synapse_connection_manager.py @@ -1,6 +1,291 @@ -from dbt.adapters.fabric import FabricConnectionManager +import struct +import time +from itertools import chain, repeat +from typing import Callable, Dict, Mapping, Optional + +import pyodbc +from azure.core.credentials import AccessToken +from azure.identity import AzureCliCredential, DefaultAzureCredential, EnvironmentCredential +from dbt.adapters.contracts.connection import Connection, ConnectionState +from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.fabric import FabricConnectionManager, __version__ +from dbt.adapters.fabric.fabric_credentials import FabricCredentials + +AZURE_CREDENTIAL_SCOPE = "https://database.windows.net//.default" +_TOKEN: Optional[AccessToken] = None +AZURE_AUTH_FUNCTION_TYPE = Callable[[FabricCredentials], AccessToken] + +logger = AdapterLogger("fabric") + + +def convert_bytes_to_mswindows_byte_string(value: bytes) -> bytes: + """ + Convert bytes to a Microsoft windows byte string. + + Parameters + ---------- + value : bytes + The bytes. + + Returns + ------- + out : bytes + The Microsoft byte string. + """ + encoded_bytes = bytes(chain.from_iterable(zip(value, repeat(0)))) + return struct.pack(" bytes: + """ + Convert an access token to a Microsoft windows byte string. + + Parameters + ---------- + token : AccessToken + The token. + + Returns + ------- + out : bytes + The Microsoft byte string. + """ + value = bytes(token.token, "UTF-8") + return convert_bytes_to_mswindows_byte_string(value) + + +def get_cli_access_token(credentials: FabricCredentials) -> AccessToken: + """ + Get an Azure access token using the CLI credentials + + First login with: + + ```bash + az login + ``` + + Parameters + ---------- + credentials: FabricConnectionManager + The credentials. + + Returns + ------- + out : AccessToken + Access token. + """ + _ = credentials + token = AzureCliCredential().get_token(AZURE_CREDENTIAL_SCOPE) + return token + + +def get_auto_access_token(credentials: FabricCredentials) -> AccessToken: + """ + Get an Azure access token automatically through azure-identity + + Parameters + ----------- + credentials: FabricCredentials + Credentials. + + Returns + ------- + out : AccessToken + The access token. + """ + token = DefaultAzureCredential().get_token(AZURE_CREDENTIAL_SCOPE) + return token + + +def get_environment_access_token(credentials: FabricCredentials) -> AccessToken: + """ + Get an Azure access token by reading environment variables + + Parameters + ----------- + credentials: FabricCredentials + Credentials. + + Returns + ------- + out : AccessToken + The access token. + """ + token = EnvironmentCredential().get_token(AZURE_CREDENTIAL_SCOPE) + return token + + +AZURE_AUTH_FUNCTIONS: Mapping[str, AZURE_AUTH_FUNCTION_TYPE] = { + "cli": get_cli_access_token, + "auto": get_auto_access_token, + "environment": get_environment_access_token, +} + + +def get_pyodbc_attrs_before(credentials: FabricCredentials) -> Dict: + """ + Get the pyodbc attrs before. + + Parameters + ---------- + credentials : FabricCredentials + Credentials. + + Returns + ------- + out : Dict + The pyodbc attrs before. + + Source + ------ + Authentication for SQL server with an access token: + https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory?view=sql-server-ver15#authenticating-with-an-access-token + """ + global _TOKEN + attrs_before: Dict + MAX_REMAINING_TIME = 300 + + authentication = str(credentials.authentication).lower() + if authentication in AZURE_AUTH_FUNCTIONS: + time_remaining = (_TOKEN.expires_on - time.time()) if _TOKEN else MAX_REMAINING_TIME + + if _TOKEN is None or (time_remaining < MAX_REMAINING_TIME): + azure_auth_function = AZURE_AUTH_FUNCTIONS[authentication] + _TOKEN = azure_auth_function(credentials) + + token_bytes = convert_access_token_to_mswindows_byte_string(_TOKEN) + sql_copt_ss_access_token = 1256 # see source in docstring + attrs_before = {sql_copt_ss_access_token: token_bytes} + else: + attrs_before = {} + + return attrs_before + + +def bool_to_connection_string_arg(key: str, value: bool) -> str: + """ + Convert a boolean to a connection string argument. + + Parameters + ---------- + key : str + The key to use in the connection string. + value : bool + The boolean to convert. + + Returns + ------- + out : str + The connection string argument. + """ + return f'{key}={"Yes" if value else "No"}' class SynapseConnectionManager(FabricConnectionManager): TYPE = "synapse" TOKEN = None + + @classmethod + def open(cls, connection: Connection) -> Connection: + if connection.state == ConnectionState.OPEN: + logger.debug("Connection is already open, skipping open.") + return connection + + credentials = cls.get_credentials(connection.credentials) + + con_str = [f"DRIVER={{{credentials.driver}}}"] + + if "\\" in credentials.host: + # If there is a backslash \ in the host name, the host is a + # SQL Server named instance. In this case then port number has to be omitted. + con_str.append(f"SERVER={credentials.host}") + else: + con_str.append(f"SERVER={credentials.host}") + + con_str.append(f"Database={credentials.database}") + + assert credentials.authentication is not None + + if "ActiveDirectory" in credentials.authentication: + con_str.append(f"Authentication={credentials.authentication}") + + if credentials.authentication == "ActiveDirectoryPassword": + con_str.append(f"UID={{{credentials.UID}}}") + con_str.append(f"PWD={{{credentials.PWD}}}") + if credentials.authentication == "ActiveDirectoryServicePrincipal": + con_str.append(f"UID={{{credentials.client_id}}}") + con_str.append(f"PWD={{{credentials.client_secret}}}") + elif credentials.authentication == "ActiveDirectoryInteractive": + con_str.append(f"UID={{{credentials.UID}}}") + + elif credentials.windows_login: + con_str.append("trusted_connection=Yes") + elif credentials.authentication == "sql": + con_str.append(f"UID={{{credentials.UID}}}") + con_str.append(f"PWD={{{credentials.PWD}}}") + + # https://docs.microsoft.com/en-us/sql/relational-databases/native-client/features/using-encryption-without-validation?view=sql-server-ver15 + assert credentials.encrypt is not None + assert credentials.trust_cert is not None + + con_str.append(bool_to_connection_string_arg("encrypt", credentials.encrypt)) + con_str.append( + bool_to_connection_string_arg("TrustServerCertificate", credentials.trust_cert) + ) + + plugin_version = __version__.version + application_name = f"dbt-{credentials.type}/{plugin_version}" + con_str.append(f"APP={application_name}") + + try: + if int(credentials.retries) > 0: + con_str.append(f"ConnectRetryCount={credentials.retries}") + + except Exception as e: + logger.debug( + "Retry count should be integer value. Skipping retries in the connection string.", + str(e), + ) + + con_str_concat = ";".join(con_str) + + index = [] + for i, elem in enumerate(con_str): + if "pwd=" in elem.lower(): + index.append(i) + + if len(index) != 0: + con_str[index[0]] = "PWD=***" + + con_str_display = ";".join(con_str) + + retryable_exceptions = [ # https://github.com/mkleehammer/pyodbc/wiki/Exceptions + pyodbc.InternalError, # not used according to docs, but defined in PEP-249 + pyodbc.OperationalError, + ] + + if credentials.authentication.lower() in AZURE_AUTH_FUNCTIONS: + # Temporary login/token errors fall into this category when using AAD + retryable_exceptions.append(pyodbc.InterfaceError) + + def connect(): + logger.debug(f"Using connection string: {con_str_display}") + + attrs_before = get_pyodbc_attrs_before(credentials) + handle = pyodbc.connect( + con_str_concat, + attrs_before=attrs_before, + autocommit=True, + timeout=credentials.login_timeout, + ) + handle.timeout = credentials.query_timeout + logger.debug(f"Connected to db: {credentials.database}") + return handle + + return cls.retry_connection( + connection, + connect=connect, + logger=logger, + retry_limit=credentials.retries, + retryable_exceptions=retryable_exceptions, + ) diff --git a/setup.py b/setup.py index 2f00c3d..742e918 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ def run(self): long_description_content_type="text/markdown", license="MIT", author=", ".join(authors_list), - url="https://github.com/dbt-msft/dbt-synapse", + url="https://github.com/microsoft/dbt-synapse", packages=find_namespace_packages(include=["dbt", "dbt.*"]), include_package_data=True, install_requires=[dbt_fabric_requirement], diff --git a/tests/conftest.py b/tests/conftest.py index ee7b069..c9dc8d3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ def pytest_addoption(parser): - parser.addoption("--profile", action="store", default="user_azure", type=str) + parser.addoption("--profile", action="store", default="user", type=str) @pytest.fixture(scope="class") @@ -54,6 +54,7 @@ def _profile_user(): "user": os.getenv("SYNAPSE_TEST_USER"), "pass": os.getenv("SYNAPSE_TEST_PASS"), "database": os.getenv("SYNAPSE_TEST_DWH_NAME"), + "authentication": "sql", }, } diff --git a/tests/functional/adapter/test_docs.py b/tests/functional/adapter/test_docs.py index 7ea26be..28b6073 100644 --- a/tests/functional/adapter/test_docs.py +++ b/tests/functional/adapter/test_docs.py @@ -19,7 +19,7 @@ class TestDocsGenerateSynapse(BaseDocsGenerate): def expected_catalog(self, project): return base_expected_catalog( project, - role="dbo", + role="dbttestuser", id_type="int", text_type="varchar", time_type="datetime2", @@ -34,7 +34,7 @@ class TestDocsGenReferencesSynapse(BaseDocsGenReferences): def expected_catalog(self, project): return expected_references_catalog( project, - role="dbo", + role="dbttestuser", id_type="int", text_type="varchar", time_type="datetime2", diff --git a/tests/functional/adapter/test_model_hooks.py b/tests/functional/adapter/test_model_hooks.py index d765705..67b3595 100644 --- a/tests/functional/adapter/test_model_hooks.py +++ b/tests/functional/adapter/test_model_hooks.py @@ -30,7 +30,7 @@ def check_hooks(self, state, project, host, count=1): assert ctx["target_schema"] == project.test_schema assert ctx["target_threads"] == 1 assert ctx["target_type"] == "synapse" - assert ctx["target_user"] == "None" + # assert ctx["target_user"] == "dbttestuser" assert ctx["target_pass"] == "" assert ( diff --git a/tests/functional/adapter/test_run_hooks.py b/tests/functional/adapter/test_run_hooks.py index cf2a32e..d803c52 100644 --- a/tests/functional/adapter/test_run_hooks.py +++ b/tests/functional/adapter/test_run_hooks.py @@ -64,7 +64,7 @@ def check_hooks(self, state, project, host): assert ctx["target_schema"] == project.test_schema assert ctx["target_threads"] == 1 assert ctx["target_type"] == "synapse" - assert ctx["target_user"] == "None" + # assert ctx["target_user"] == "None" assert ctx["target_pass"] == "" assert (