diff --git a/cosmos/profiles/snowflake/base.py b/cosmos/profiles/snowflake/base.py new file mode 100644 index 000000000..599a9c8e5 --- /dev/null +++ b/cosmos/profiles/snowflake/base.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import Any + +from cosmos.profiles.base import BaseProfileMapping + +DEFAULT_AWS_REGION = "us-west-2" + + +class SnowflakeBaseProfileMapping(BaseProfileMapping): + + @property + def profile(self) -> dict[str, Any | None]: + """Gets profile.""" + profile_vars = { + **self.mapped_params, + **self.profile_args, + } + return profile_vars + + def transform_account(self, account: str) -> str: + """Transform the account to the format . if it's not already.""" + region = self.conn.extra_dejson.get("region") + # + if region and region != DEFAULT_AWS_REGION and region not in account: + account = f"{account}.{region}" + + return str(account) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py index 70722dd59..63a6c68d3 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py @@ -5,13 +5,13 @@ import json from typing import TYPE_CHECKING, Any -from ..base import BaseProfileMapping +from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping if TYPE_CHECKING: from airflow.models import Connection -class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): +class SnowflakeEncryptedPrivateKeyPemProfileMapping(SnowflakeBaseProfileMapping): """ Maps Airflow Snowflake connections to dbt profiles if they use a user/private key. https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication @@ -75,20 +75,7 @@ def conn(self) -> Connection: @property def profile(self) -> dict[str, Any | None]: """Gets profile.""" - profile_vars = { - **self.mapped_params, - **self.profile_args, - "private_key": self.get_env_var_format("private_key"), - "private_key_passphrase": self.get_env_var_format("private_key_passphrase"), - } - - # remove any null values + profile_vars = super().profile + profile_vars["private_key"] = self.get_env_var_format("private_key") + profile_vars["private_key_passphrase"] = self.get_env_var_format("private_key_passphrase") return self.filter_null(profile_vars) - - def transform_account(self, account: str) -> str: - """Transform the account to the format . if it's not already.""" - region = self.conn.extra_dejson.get("region") - if region and region not in account: - account = f"{account}.{region}" - - return str(account) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index e217a6c22..6f35dad45 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -5,13 +5,13 @@ import json from typing import TYPE_CHECKING, Any -from ..base import BaseProfileMapping +from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping if TYPE_CHECKING: from airflow.models import Connection -class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping): +class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(SnowflakeBaseProfileMapping): """ Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path. https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication @@ -74,20 +74,6 @@ def conn(self) -> Connection: @property def profile(self) -> dict[str, Any | None]: """Gets profile.""" - profile_vars = { - **self.mapped_params, - **self.profile_args, - # private_key_passphrase should always get set as env var - "private_key_passphrase": self.get_env_var_format("private_key_passphrase"), - } - - # remove any null values + profile_vars = super().profile + profile_vars["private_key_passphrase"] = self.get_env_var_format("private_key_passphrase") return self.filter_null(profile_vars) - - def transform_account(self, account: str) -> str: - """Transform the account to the format . if it's not already.""" - region = self.conn.extra_dejson.get("region") - if region and region not in account: - account = f"{account}.{region}" - - return str(account) diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index 3fc6595c9..93c29793b 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -5,13 +5,13 @@ import json from typing import TYPE_CHECKING, Any -from ..base import BaseProfileMapping +from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping if TYPE_CHECKING: from airflow.models import Connection -class SnowflakeUserPasswordProfileMapping(BaseProfileMapping): +class SnowflakeUserPasswordProfileMapping(SnowflakeBaseProfileMapping): """ Maps Airflow Snowflake connections to dbt profiles if they use a user/password. https://docs.getdbt.com/reference/warehouse-setups/snowflake-setup @@ -76,20 +76,7 @@ def conn(self) -> Connection: @property def profile(self) -> dict[str, Any | None]: """Gets profile.""" - profile_vars = { - **self.mapped_params, - **self.profile_args, - # password should always get set as env var - "password": self.get_env_var_format("password"), - } - - # remove any null values + profile_vars = super().profile + # password should always get set as env var + profile_vars["password"] = self.get_env_var_format("password") return self.filter_null(profile_vars) - - def transform_account(self, account: str) -> str: - """Transform the account to the format . if it's not already.""" - region = self.conn.extra_dejson.get("region") - if region and region not in account: - account = f"{account}.{region}" - - return str(account) diff --git a/cosmos/profiles/snowflake/user_privatekey.py b/cosmos/profiles/snowflake/user_privatekey.py index c74194b7a..40a016af7 100644 --- a/cosmos/profiles/snowflake/user_privatekey.py +++ b/cosmos/profiles/snowflake/user_privatekey.py @@ -5,13 +5,13 @@ import json from typing import TYPE_CHECKING, Any -from ..base import BaseProfileMapping +from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping if TYPE_CHECKING: from airflow.models import Connection -class SnowflakePrivateKeyPemProfileMapping(BaseProfileMapping): +class SnowflakePrivateKeyPemProfileMapping(SnowflakeBaseProfileMapping): """ Maps Airflow Snowflake connections to dbt profiles if they use a user/private key. https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication @@ -65,20 +65,7 @@ def conn(self) -> Connection: @property def profile(self) -> dict[str, Any | None]: """Gets profile.""" - profile_vars = { - **self.mapped_params, - **self.profile_args, - # private_key should always get set as env var - "private_key": self.get_env_var_format("private_key"), - } - - # remove any null values + profile_vars = super().profile + # private_key should always get set as env var + profile_vars["private_key"] = self.get_env_var_format("private_key") return self.filter_null(profile_vars) - - def transform_account(self, account: str) -> str: - """Transform the account to the format . if it's not already.""" - region = self.conn.extra_dejson.get("region") - if region and region not in account: - account = f"{account}.{region}" - - return str(account) diff --git a/tests/profiles/snowflake/test_snowflake_base.py b/tests/profiles/snowflake/test_snowflake_base.py new file mode 100644 index 000000000..ee8f6c6b3 --- /dev/null +++ b/tests/profiles/snowflake/test_snowflake_base.py @@ -0,0 +1,19 @@ +from unittest.mock import patch + +from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping + + +@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn.extra_dejson", {"region": "us-west-2"}) +@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn") +def test_default_region(mock_conn): + profile_mapping = SnowflakeBaseProfileMapping(conn_id="fake-conn") + response = profile_mapping.transform_account("myaccount") + assert response == "myaccount" + + +@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn.extra_dejson", {"region": "us-east-1"}) +@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn") +def test_non_default_region(mock_conn): + profile_mapping = SnowflakeBaseProfileMapping(conn_id="fake-conn") + response = profile_mapping.transform_account("myaccount") + assert response == "myaccount.us-east-1"