diff --git a/dbt/adapters/snowflake/connections.py b/dbt/adapters/snowflake/connections.py index aca115b4b..99ecf2948 100644 --- a/dbt/adapters/snowflake/connections.py +++ b/dbt/adapters/snowflake/connections.py @@ -43,7 +43,7 @@ from dbt.adapters.sql import SQLConnectionManager from dbt.adapters.events.logging import AdapterLogger from dbt_common.events.functions import warn_or_error -from dbt.adapters.events.types import AdapterEventWarning +from dbt.adapters.events.types import AdapterEventWarning, AdapterEventError from dbt_common.ui import line_wrap_message, warning_tag @@ -70,7 +70,7 @@ class SnowflakeAdapterResponse(AdapterResponse): @dataclass class SnowflakeCredentials(Credentials): account: str - user: str + user: Optional[str] = None warehouse: Optional[str] = None role: Optional[str] = None password: Optional[str] = None @@ -96,15 +96,29 @@ class SnowflakeCredentials(Credentials): reuse_connections: Optional[bool] = None def __post_init__(self): - if self.authenticator != "oauth" and ( - self.oauth_client_secret or self.oauth_client_id or self.token - ): + if self.authenticator != "oauth" and (self.oauth_client_secret or self.oauth_client_id): # the user probably forgot to set 'authenticator' like I keep doing warn_or_error( AdapterEventWarning( base_msg="Authenticator is not set to oauth, but an oauth-only parameter is set! Did you mean to set authenticator: oauth?" ) ) + + if self.authenticator not in ["oauth", "jwt"]: + if self.token: + warn_or_error( + AdapterEventWarning( + base_msg=( + "The token parameter was set, but the authenticator was " + "not set to 'oauth' or 'jwt'." + ) + ) + ) + + if not self.user: + # The user attribute is only optional if 'authenticator' is 'jwt' or 'oauth' + warn_or_error(AdapterEventError(base_msg="'user' is a required property.")) + self.account = self.account.replace("_", "-") @property @@ -146,6 +160,8 @@ def auth_args(self): # Pull all of the optional authentication args for the connector, # let connector handle the actual arg validation result = {} + if self.user: + result["user"] = self.user if self.password: result["password"] = self.password if self.host: @@ -180,6 +196,14 @@ def auth_args(self): ) result["token"] = token + + elif self.authenticator == "jwt": + # If authenticator is 'jwt', then the 'token' value should be used + # unmodified. We expose this as 'jwt' in the profile, but the value + # passed into the snowflake.connect method should still be 'oauth' + result["token"] = self.token + result["authenticator"] = "oauth" + # enable id token cache for linux result["client_store_temporary_credential"] = True # enable mfa token cache for linux @@ -346,7 +370,6 @@ def connect(): handle = snowflake.connector.connect( account=creds.account, - user=creds.user, database=creds.database, schema=creds.schema, warehouse=creds.warehouse, diff --git a/tests/unit/test_snowflake_adapter.py b/tests/unit/test_snowflake_adapter.py index ff92b9b65..f6a768da8 100644 --- a/tests/unit/test_snowflake_adapter.py +++ b/tests/unit/test_snowflake_adapter.py @@ -550,6 +550,38 @@ def test_authenticator_private_key_authentication_no_passphrase(self, mock_get_p ] ) + def test_authenticator_jwt_authentication(self): + self.config.credentials = self.config.credentials.replace( + authenticator="jwt", token="my-jwt-token", user=None + ) + self.adapter = SnowflakeAdapter(self.config, get_context("spawn")) + conn = self.adapter.connections.set_connection_name(name="new_connection_with_new_config") + + self.snowflake.assert_not_called() + conn.handle + self.snowflake.assert_has_calls( + [ + mock.call( + account="test-account", + autocommit=True, + client_session_keep_alive=False, + database="test_database", + role=None, + schema="public", + warehouse="test_warehouse", + authenticator="oauth", + token="my-jwt-token", + private_key=None, + application="dbt", + client_request_mfa_token=True, + client_store_temporary_credential=True, + insecure_mode=False, + session_parameters={}, + reuse_connections=None, + ) + ] + ) + def test_query_tag(self): self.config.credentials = self.config.credentials.replace( password="test_password", query_tag="test_query_tag"