-
Notifications
You must be signed in to change notification settings - Fork 177
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ProfileMapping for Snowflake encrypted private key path (#608)
Add support for Snowflake Airflow Connections with encrypted private key paths via `extra.private_key_file` and `password` configuration. Closes: #607
- Loading branch information
1 parent
d829a04
commit b7381c8
Showing
5 changed files
with
316 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
"Maps Airflow Snowflake connections to dbt profiles if they use a user/private key." | ||
from __future__ import annotations | ||
|
||
import json | ||
from typing import TYPE_CHECKING, Any | ||
|
||
from ..base import BaseProfileMapping | ||
|
||
if TYPE_CHECKING: | ||
from airflow.models import Connection | ||
|
||
|
||
class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
    """
    Maps Airflow Snowflake connections to dbt profiles when they authenticate with a user and an
    encrypted private key file: the key path is taken from ``extra.private_key_file`` and the key
    passphrase from the connection ``password``.
    https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
    https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html
    """

    airflow_connection_type: str = "snowflake"
    dbt_profile_type: str = "snowflake"
    is_community: bool = True

    # dbt profile fields this mapping declares as required
    required_fields = [
        "account",
        "user",
        "database",
        "warehouse",
        "schema",
        "private_key_passphrase",
        "private_key_path",
    ]
    # declared secret; `profile` references it through an env var placeholder
    secret_fields = [
        "private_key_passphrase",
    ]
    # dbt profile field -> where its value lives on the Airflow connection
    airflow_param_mapping = {
        "account": "extra.account",
        "user": "login",
        "database": "extra.database",
        "warehouse": "extra.warehouse",
        "schema": "schema",
        "role": "extra.role",
        "private_key_passphrase": "password",
        "private_key_path": "extra.private_key_file",
    }

    @property
    def conn(self) -> Connection:
        """
        Return the parent's connection with its extras normalised.

        Snowflake extras used to be stored under keys such as 'extra__snowflake__account' and
        are now stored as plain 'account'. When the legacy account key is present, the legacy
        prefix is removed from every key; the extras are then written back to the connection
        as JSON.
        """
        legacy_prefix = "extra__snowflake__"
        connection = super().conn
        extras = connection.extra_dejson

        if extras.get(legacy_prefix + "account"):
            extras = {name.replace(legacy_prefix, ""): val for name, val in extras.items()}

        connection.extra = json.dumps(extras)
        return connection

    @property
    def profile(self) -> dict[str, Any | None]:
        "Builds the dbt profile: mapped params, overridden by profile_args, nulls removed."
        merged = {**self.mapped_params, **self.profile_args}
        # the passphrase is always referenced through an env var, whatever was supplied above
        merged["private_key_passphrase"] = self.get_env_var_format("private_key_passphrase")

        return self.filter_null(merged)

    def transform_account(self, account: str) -> str:
        "Returns the account as <account>.<region>, unless the region already appears in it."
        region = self.conn.extra_dejson.get("region")
        if not region or region in account:
            return str(account)

        return str(f"{account}.{region}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
216 changes: 216 additions & 0 deletions
216
tests/profiles/snowflake/test_snowflake_user_encrypted_privatekey.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
"Tests for the Snowflake user/private key profile." | ||
|
||
import json | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
from airflow.models.connection import Connection | ||
|
||
from cosmos.profiles import get_automatic_profile_mapping | ||
from cosmos.profiles.snowflake import ( | ||
SnowflakeEncryptedPrivateKeyPemProfileMapping, | ||
) | ||
|
||
|
||
@pytest.fixture()
def mock_snowflake_conn():  # type: ignore
    """
    Yields a Snowflake connection configured with an encrypted private key file, while
    ``BaseHook.get_connection`` is patched to return that connection.
    """
    extras = {
        "account": "my_account",
        "region": "my_region",
        "database": "my_database",
        "warehouse": "my_warehouse",
        "private_key_file": "path/to/private_key.p8",
    }
    snowflake_conn = Connection(
        conn_id="my_snowflake_pk_connection",
        conn_type="snowflake",
        login="my_user",
        schema="my_schema",
        password="secret",
        extra=json.dumps(extras),
    )

    # the patch stays active for as long as the requesting test runs
    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=snowflake_conn):
        yield snowflake_conn
|
||
|
||
def test_connection_claiming() -> None:
    """
    Tests that the Snowflake profile mapping only claims a connection when every value it
    needs is present, both at the top level of the connection and inside its extras.
    """
    complete_extra = {
        "account": "my_account",
        "database": "my_database",
        "warehouse": "my_warehouse",
        "private_key_file": "path/to/private_key.p8",
    }
    potential_values = {
        "conn_type": "snowflake",
        "login": "my_user",
        "schema": "my_database",
        "password": "secret",
        "extra": json.dumps(complete_extra),
    }

    # if we're missing any of the top-level values, it shouldn't claim
    for key in potential_values:
        values = potential_values.copy()
        del values[key]
        conn = Connection(**values)  # type: ignore

        with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
            profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(
                conn,
            )
            assert not profile_mapping.can_claim_connection(), f"claimed a connection without {key!r}"

    # if we're missing exactly one extra value, it shouldn't claim either. Every other extra
    # (including private_key_file) stays in place so the failure can only come from the
    # key that was dropped.
    for key in complete_extra:
        partial_extra = {name: value for name, value in complete_extra.items() if name != key}
        conn = Connection(**potential_values)  # type: ignore
        conn.extra = json.dumps(partial_extra)

        with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
            profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn)
            assert not profile_mapping.can_claim_connection(), f"claimed a connection without extra {key!r}"

    # if we have them all, it should claim
    conn = Connection(**potential_values)  # type: ignore
    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
        profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn)
        assert profile_mapping.can_claim_connection()
|
||
|
||
def test_profile_mapping_selected(
    mock_snowflake_conn: Connection,
) -> None:
    """
    Checks that automatic selection resolves to the encrypted private key mapping.
    """
    selected = get_automatic_profile_mapping(mock_snowflake_conn.conn_id)

    assert isinstance(selected, SnowflakeEncryptedPrivateKeyPemProfileMapping)
|
||
|
||
def test_profile_args(
    mock_snowflake_conn: Connection,
) -> None:
    """
    Checks the generated profile against the values held by the mocked connection.
    """
    mapping = get_automatic_profile_mapping(mock_snowflake_conn.conn_id)

    extras = mock_snowflake_conn.extra_dejson
    # the fixture sets a region, so the account is expected as <account>.<region>
    expected_account = f"{extras.get('account')}.{extras.get('region')}"

    expected_profile = {
        "type": mock_snowflake_conn.conn_type,
        "user": mock_snowflake_conn.login,
        "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
        "private_key_path": extras.get("private_key_file"),
        "schema": mock_snowflake_conn.schema,
        "account": expected_account,
        "database": extras.get("database"),
        "warehouse": extras.get("warehouse"),
    }

    assert mapping.profile == expected_profile
|
||
|
||
def test_profile_args_overrides(
    mock_snowflake_conn: Connection,
) -> None:
    """
    Checks that values passed as profile_args take precedence over the connection's values.
    """
    overrides = {"database": "my_db_override"}
    mapping = get_automatic_profile_mapping(
        mock_snowflake_conn.conn_id,
        profile_args=overrides,
    )
    assert mapping.profile_args == {
        "database": "my_db_override",
    }

    extras = mock_snowflake_conn.extra_dejson
    # the fixture sets a region, so the account is expected as <account>.<region>
    expected_account = f"{extras.get('account')}.{extras.get('region')}"

    expected_profile = {
        "type": mock_snowflake_conn.conn_type,
        "user": mock_snowflake_conn.login,
        "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
        "private_key_path": extras.get("private_key_file"),
        "schema": mock_snowflake_conn.schema,
        "account": expected_account,
        "database": "my_db_override",
        "warehouse": extras.get("warehouse"),
    }

    assert mapping.profile == expected_profile
|
||
|
||
def test_profile_env_vars(
    mock_snowflake_conn: Connection,
) -> None:
    """
    Checks that the connection password is exported as the private key passphrase env var.
    """
    mapping = get_automatic_profile_mapping(mock_snowflake_conn.conn_id)

    expected_env = {
        "COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE": mock_snowflake_conn.password,
    }
    assert mapping.env_vars == expected_env
|
||
|
||
def test_old_snowflake_format() -> None:
    """
    Tests that the old format still works.

    Extras stored under the legacy 'extra__snowflake__*' keys must produce the same profile
    as the plain keys. The expected values are spelled out as literals instead of being read
    back from the connection, because building the profile rewrites ``conn.extra``; reading
    the expectation from the connection afterwards would compare the mapping with its own
    output.
    """
    conn = Connection(
        conn_id="my_snowflake_connection",
        conn_type="snowflake",
        login="my_user",
        schema="my_schema",
        password="secret",
        extra=json.dumps(
            {
                "extra__snowflake__account": "my_account",
                "extra__snowflake__database": "my_database",
                "extra__snowflake__warehouse": "my_warehouse",
                "extra__snowflake__private_key_file": "path/to/private_key.p8",
            }
        ),
    )

    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
        profile_mapping = SnowflakeEncryptedPrivateKeyPemProfileMapping(conn)
        assert profile_mapping.profile == {
            "type": "snowflake",
            "user": "my_user",
            "private_key_passphrase": "{{ env_var('COSMOS_CONN_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE') }}",
            "private_key_path": "path/to/private_key.p8",
            "schema": "my_schema",
            # no region in the extras, so the account is left untouched
            "account": "my_account",
            "database": "my_database",
            "warehouse": "my_warehouse",
        }