From 22b20f1557d97d9076c9bdcad8b0f9c9f7b0a4df Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 19 Dec 2024 09:29:51 +0000 Subject: [PATCH] Fix Snowflake Profile mapping when using AWS default region (#1406) When using a Cosmos Snowflake Profile mapping using a Snowflake account set in the AWS default region, Cosmos would fail if the default region was specified in the Airflow connection. The dbt docs state: > For AWS accounts in the US West default region, you can use abc123 (without any other segments). For some AWS accounts you will have to append the region and/or cloud platform. For example, abc123.eu-west-1 or abc123.eu-west-2.aws. https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#account Although it seems that defining the default region would be optional, a Cosmos user reported facing 404 and seeing a dbt error message when attempting to use `SnowflakeUserPasswordProfileMapping` with an Airflow Snowflake connection that defined the region `us-west-2`. ![snowflake-404](https://github.com/user-attachments/assets/c1884fff-1cad-4c57-b2f3-11a4f44b085b) We solved the issue by removing the region `us-west-2` from the connection. Since this restriction only applies to AWS and this Snowflake region only exists to AWS, this change seems safe: ![Screenshot 2024-12-18 at 18 45 31](https://github.com/user-attachments/assets/ff2f8a0b-578b-4a62-9fc3-258a43148775) --- cosmos/profiles/snowflake/base.py | 28 +++++++++++++++++++ .../user_encrypted_privatekey_env_variable.py | 23 ++++----------- .../user_encrypted_privatekey_file.py | 22 +++------------ cosmos/profiles/snowflake/user_pass.py | 23 ++++----------- cosmos/profiles/snowflake/user_privatekey.py | 23 ++++----------- .../profiles/snowflake/test_snowflake_base.py | 19 +++++++++++++ 6 files changed, 66 insertions(+), 72 deletions(-) create mode 100644 cosmos/profiles/snowflake/base.py create mode 100644 tests/profiles/snowflake/test_snowflake_base.py diff --git a/cosmos/profiles/snowflake/base.py b/cosmos/profiles/snowflake/base.py new file mode 100644 index 000000000..599a9c8e5 --- /dev/null +++ b/cosmos/profiles/snowflake/base.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import Any + +from cosmos.profiles.base import BaseProfileMapping + +DEFAULT_AWS_REGION = "us-west-2" + + +class SnowflakeBaseProfileMapping(BaseProfileMapping): + + @property + def profile(self) -> dict[str, Any | None]: + """Gets profile.""" + profile_vars = { + **self.mapped_params, + **self.profile_args, + } + return profile_vars + + def transform_account(self, account: str) -> str: + """Transform the account to the format . if it's not already.""" + region = self.conn.extra_dejson.get("region") + # + if region and region != DEFAULT_AWS_REGION and region not in account: + account = f"{account}.{region}" + + return str(account) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py index 70722dd59..63a6c68d3 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_env_variable.py @@ -5,13 +5,13 @@ import json from typing import TYPE_CHECKING, Any -from ..base import BaseProfileMapping +from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping if TYPE_CHECKING: from airflow.models import Connection -class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping): +class SnowflakeEncryptedPrivateKeyPemProfileMapping(SnowflakeBaseProfileMapping): """ Maps Airflow Snowflake connections to dbt profiles if they use a user/private key. https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication @@ -75,20 +75,7 @@ def conn(self) -> Connection: @property def profile(self) -> dict[str, Any | None]: """Gets profile.""" - profile_vars = { - **self.mapped_params, - **self.profile_args, - "private_key": self.get_env_var_format("private_key"), - "private_key_passphrase": self.get_env_var_format("private_key_passphrase"), - } - - # remove any null values + profile_vars = super().profile + profile_vars["private_key"] = self.get_env_var_format("private_key") + profile_vars["private_key_passphrase"] = self.get_env_var_format("private_key_passphrase") return self.filter_null(profile_vars) - - def transform_account(self, account: str) -> str: - """Transform the account to the format . if it's not already.""" - region = self.conn.extra_dejson.get("region") - if region and region not in account: - account = f"{account}.{region}" - - return str(account) diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py index e217a6c22..6f35dad45 100644 --- a/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py +++ b/cosmos/profiles/snowflake/user_encrypted_privatekey_file.py @@ -5,13 +5,13 @@ import json from typing import TYPE_CHECKING, Any -from ..base import BaseProfileMapping +from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping if TYPE_CHECKING: from airflow.models import Connection -class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping): +class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(SnowflakeBaseProfileMapping): """ Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path. https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication @@ -74,20 +74,6 @@ def conn(self) -> Connection: @property def profile(self) -> dict[str, Any | None]: """Gets profile.""" - profile_vars = { - **self.mapped_params, - **self.profile_args, - # private_key_passphrase should always get set as env var - "private_key_passphrase": self.get_env_var_format("private_key_passphrase"), - } - - # remove any null values + profile_vars = super().profile + profile_vars["private_key_passphrase"] = self.get_env_var_format("private_key_passphrase") return self.filter_null(profile_vars) - - def transform_account(self, account: str) -> str: - """Transform the account to the format . if it's not already.""" - region = self.conn.extra_dejson.get("region") - if region and region not in account: - account = f"{account}.{region}" - - return str(account) diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index 3fc6595c9..93c29793b 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -5,13 +5,13 @@ import json from typing import TYPE_CHECKING, Any -from ..base import BaseProfileMapping +from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping if TYPE_CHECKING: from airflow.models import Connection -class SnowflakeUserPasswordProfileMapping(BaseProfileMapping): +class SnowflakeUserPasswordProfileMapping(SnowflakeBaseProfileMapping): """ Maps Airflow Snowflake connections to dbt profiles if they use a user/password. https://docs.getdbt.com/reference/warehouse-setups/snowflake-setup @@ -76,20 +76,7 @@ def conn(self) -> Connection: @property def profile(self) -> dict[str, Any | None]: """Gets profile.""" - profile_vars = { - **self.mapped_params, - **self.profile_args, - # password should always get set as env var - "password": self.get_env_var_format("password"), - } - - # remove any null values + profile_vars = super().profile + # password should always get set as env var + profile_vars["password"] = self.get_env_var_format("password") return self.filter_null(profile_vars) - - def transform_account(self, account: str) -> str: - """Transform the account to the format . if it's not already.""" - region = self.conn.extra_dejson.get("region") - if region and region not in account: - account = f"{account}.{region}" - - return str(account) diff --git a/cosmos/profiles/snowflake/user_privatekey.py b/cosmos/profiles/snowflake/user_privatekey.py index c74194b7a..40a016af7 100644 --- a/cosmos/profiles/snowflake/user_privatekey.py +++ b/cosmos/profiles/snowflake/user_privatekey.py @@ -5,13 +5,13 @@ import json from typing import TYPE_CHECKING, Any -from ..base import BaseProfileMapping +from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping if TYPE_CHECKING: from airflow.models import Connection -class SnowflakePrivateKeyPemProfileMapping(BaseProfileMapping): +class SnowflakePrivateKeyPemProfileMapping(SnowflakeBaseProfileMapping): """ Maps Airflow Snowflake connections to dbt profiles if they use a user/private key. https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication @@ -65,20 +65,7 @@ def conn(self) -> Connection: @property def profile(self) -> dict[str, Any | None]: """Gets profile.""" - profile_vars = { - **self.mapped_params, - **self.profile_args, - # private_key should always get set as env var - "private_key": self.get_env_var_format("private_key"), - } - - # remove any null values + profile_vars = super().profile + # private_key should always get set as env var + profile_vars["private_key"] = self.get_env_var_format("private_key") return self.filter_null(profile_vars) - - def transform_account(self, account: str) -> str: - """Transform the account to the format . if it's not already.""" - region = self.conn.extra_dejson.get("region") - if region and region not in account: - account = f"{account}.{region}" - - return str(account) diff --git a/tests/profiles/snowflake/test_snowflake_base.py b/tests/profiles/snowflake/test_snowflake_base.py new file mode 100644 index 000000000..ee8f6c6b3 --- /dev/null +++ b/tests/profiles/snowflake/test_snowflake_base.py @@ -0,0 +1,19 @@ +from unittest.mock import patch + +from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping + + +@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn.extra_dejson", {"region": "us-west-2"}) +@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn") +def test_default_region(mock_conn): + profile_mapping = SnowflakeBaseProfileMapping(conn_id="fake-conn") + response = profile_mapping.transform_account("myaccount") + assert response == "myaccount" + + +@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn.extra_dejson", {"region": "us-east-1"}) +@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn") +def test_non_default_region(mock_conn): + profile_mapping = SnowflakeBaseProfileMapping(conn_id="fake-conn") + response = profile_mapping.transform_account("myaccount") + assert response == "myaccount.us-east-1"