diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py
index e75b6c25e..47c7309ab 100644
--- a/cosmos/profiles/__init__.py
+++ b/cosmos/profiles/__init__.py
@@ -16,6 +16,7 @@
 from .redshift.user_pass import RedshiftUserPasswordProfileMapping
 from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping
 from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping
+from .snowflake.user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping
 from .spark.thrift import SparkThriftProfileMapping
 from .trino.certificate import TrinoCertificateProfileMapping
 from .trino.jwt import TrinoJWTProfileMapping
@@ -31,6 +32,7 @@
     PostgresUserPasswordProfileMapping,
     RedshiftUserPasswordProfileMapping,
     SnowflakeUserPasswordProfileMapping,
+    SnowflakeEncryptedPrivateKeyPemProfileMapping,
     SnowflakePrivateKeyPemProfileMapping,
     SparkThriftProfileMapping,
     ExasolUserPasswordProfileMapping,
diff --git a/cosmos/profiles/snowflake/__init__.py b/cosmos/profiles/snowflake/__init__.py
index 450dc3772..26c3fb595 100644
--- a/cosmos/profiles/snowflake/__init__.py
+++ b/cosmos/profiles/snowflake/__init__.py
@@ -2,5 +2,10 @@
 
 from .user_pass import SnowflakeUserPasswordProfileMapping
 from .user_privatekey import SnowflakePrivateKeyPemProfileMapping
+from .user_encrypted_privatekey import SnowflakeEncryptedPrivateKeyPemProfileMapping
 
-__all__ = ["SnowflakeUserPasswordProfileMapping", "SnowflakePrivateKeyPemProfileMapping"]
+__all__ = [
+    "SnowflakeUserPasswordProfileMapping",
+    "SnowflakePrivateKeyPemProfileMapping",
+    "SnowflakeEncryptedPrivateKeyPemProfileMapping",
+]
diff --git a/cosmos/profiles/snowflake/user_encrypted_privatekey.py b/cosmos/profiles/snowflake/user_encrypted_privatekey.py
new file mode 100644
index 000000000..0623598be
--- /dev/null
+++ b/cosmos/profiles/snowflake/user_encrypted_privatekey.py
@@ -0,0 +1,85 @@
+"Maps Airflow Snowflake connections to dbt profiles if they use a user and an encrypted private key."
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any
+
+from ..base import BaseProfileMapping
+
+if TYPE_CHECKING:
+    from airflow.models import Connection
+
+
+class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
+    """
+    Maps Airflow Snowflake connections to dbt profiles if they use a user and an encrypted (passphrase-protected) private key.
+    https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
+    https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html
+    """
+
+    airflow_connection_type: str = "snowflake"
+    dbt_profile_type: str = "snowflake"
+    is_community: bool = True
+
+    required_fields = [
+        "account",
+        "user",
+        "database",
+        "warehouse",
+        "schema",
+        "private_key_passphrase",
+        "private_key_path",
+    ]
+    secret_fields = [
+        "private_key_passphrase",
+    ]
+    airflow_param_mapping = {
+        "account": "extra.account",
+        "user": "login",
+        "database": "extra.database",
+        "warehouse": "extra.warehouse",
+        "schema": "schema",
+        "role": "extra.role",
+        "private_key_passphrase": "password",
+        "private_key_path": "extra.private_key_file",
+    }
+
+    @property
+    def conn(self) -> Connection:
+        """
+        Snowflake can be odd because the fields used to be stored with keys in the format
+        'extra__snowflake__account', but now are stored as 'account'.
+
+        This standardizes the keys to be 'account', 'database', etc.
+ """ + conn = super().conn + + conn_dejson = conn.extra_dejson + + if conn_dejson.get("extra__snowflake__account"): + conn_dejson = {key.replace("extra__snowflake__", ""): value for key, value in conn_dejson.items()} + + conn.extra = json.dumps(conn_dejson) + + return conn + + @property + def profile(self) -> dict[str, Any | None]: + "Gets profile." + profile_vars = { + **self.mapped_params, + **self.profile_args, + # private_key_passphrase should always get set as env var + "private_key_passphrase": self.get_env_var_format("private_key_passphrase"), + } + + # remove any null values + return self.filter_null(profile_vars) + + def transform_account(self, account: str) -> str: + "Transform the account to the format . if it's not already." + region = self.conn.extra_dejson.get("region") + if region and region not in account: + account = f"{account}.{region}" + + return str(account) diff --git a/cosmos/profiles/snowflake/user_pass.py b/cosmos/profiles/snowflake/user_pass.py index 8be66414a..2e1025a2c 100644 --- a/cosmos/profiles/snowflake/user_pass.py +++ b/cosmos/profiles/snowflake/user_pass.py @@ -41,6 +41,13 @@ class SnowflakeUserPasswordProfileMapping(BaseProfileMapping): "role": "extra.role", } + def can_claim_connection(self) -> bool: + # Make sure this isn't a private key path credential + result = super().can_claim_connection() + if result and self.conn.extra_dejson.get("private_key_file") is not None: + return False + return result + @property def conn(self) -> Connection: """ diff --git a/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey.py b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey.py new file mode 100644 index 000000000..b61b85094 --- /dev/null +++ b/tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey.py @@ -0,0 +1,216 @@ +"Tests for the Snowflake user/private key profile." + +import json +from unittest.mock import patch + +import pytest +from airflow.models.connection import Connection + +from cosmos.profiles import get_automatic_profile_mapping +from cosmos.profiles.snowflake import ( + SnowflakeEncryptedPrivateKeyPemProfileMapping, +) + + +@pytest.fixture() +def mock_snowflake_conn(): # type: ignore + """ + Sets the connection as an environment variable. + """ + conn = Connection( + conn_id="my_snowflake_pk_connection", + conn_type="snowflake", + login="my_user", + schema="my_schema", + password="secret", + extra=json.dumps( + { + "account": "my_account", + "region": "my_region", + "database": "my_database", + "warehouse": "my_warehouse", + "private_key_file": "path/to/private_key.p8", + } + ), + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + +def test_connection_claiming() -> None: + """ + Tests that the Snowflake profile mapping claims the correct connection type. 
+ """ + potential_values = { + "conn_type": "snowflake", + "login": "my_user", + "schema": "my_database", + "password": "secret", + "extra": json.dumps( + { + "account": "my_account", + "database": "my_database", + "warehouse": "my_warehouse", + "private_key_file": "path/to/private_key.p8", + } + ), + } + + # if we're missing any of the values, it shouldn't claim + for key in potential_values: + values = potential_values.copy() + del values[key] + conn = Connection(**values) # type: ignore + + print("testing with", values) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping( + conn, + ) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the account + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"database": "my_database", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' + print("testing with", conn.extra) + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the database + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"account": "my_account", "warehouse": "my_warehouse", "private_key_content": "my_private_key"}' + print("testing with", conn.extra) + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # test when we're missing the warehouse + conn = Connection(**potential_values) # type: ignore + conn.extra = '{"account": "my_account", "database": "my_database", "private_key_content": "my_private_key"}' + print("testing with", conn.extra) + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # if we have them all, it should claim + conn = Connection(**potential_values) # type: ignore + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + assert profile_mapping.can_claim_connection() + + +def test_profile_mapping_selected( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the correct profile mapping is selected. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + ) + assert isinstance(profile_mapping, SnowflakeEncryptedPrivateKeyPemProfileMapping) + + +def test_profile_args( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the profile values get set correctly. 
+ """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + ) + + mock_account = mock_snowflake_conn.extra_dejson.get("account") + mock_region = mock_snowflake_conn.extra_dejson.get("region") + + assert profile_mapping.profile == { + "type": mock_snowflake_conn.conn_type, + "user": mock_snowflake_conn.login, + "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", + "private_key_path": mock_snowflake_conn.extra_dejson.get("private_key_file"), + "schema": mock_snowflake_conn.schema, + "account": f"{mock_account}.{mock_region}", + "database": mock_snowflake_conn.extra_dejson.get("database"), + "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + } + + +def test_profile_args_overrides( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that you can override the profile values. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + profile_args={"database": "my_db_override"}, + ) + assert profile_mapping.profile_args == { + "database": "my_db_override", + } + + mock_account = mock_snowflake_conn.extra_dejson.get("account") + mock_region = mock_snowflake_conn.extra_dejson.get("region") + + assert profile_mapping.profile == { + "type": mock_snowflake_conn.conn_type, + "user": mock_snowflake_conn.login, + "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", + "private_key_path": mock_snowflake_conn.extra_dejson.get("private_key_file"), + "schema": mock_snowflake_conn.schema, + "account": f"{mock_account}.{mock_region}", + "database": "my_db_override", + "warehouse": mock_snowflake_conn.extra_dejson.get("warehouse"), + } + + +def test_profile_env_vars( + mock_snowflake_conn: Connection, +) -> None: + """ + Tests that the environment variables get set correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_snowflake_conn.conn_id, + ) + assert profile_mapping.env_vars == { + "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE": mock_snowflake_conn.password, + } + + +def test_old_snowflake_format() -> None: + """ + Tests that the old format still works. + """ + conn = Connection( + conn_id="my_snowflake_connection", + conn_type="snowflake", + login="my_user", + schema="my_schema", + password="secret", + extra=json.dumps( + { + "extra__snowflake__account": "my_account", + "extra__snowflake__database": "my_database", + "extra__snowflake__warehouse": "my_warehouse", + "extra__snowflake__private_key_file": "path/to/private_key.p8", + } + ), + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn) + assert profile_mapping.profile == { + "type": conn.conn_type, + "user": conn.login, + "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}", + "private_key_path": conn.extra_dejson.get("private_key_file"), + "schema": conn.schema, + "account": conn.extra_dejson.get("account"), + "database": conn.extra_dejson.get("database"), + "warehouse": conn.extra_dejson.get("warehouse"), + }