diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py index 00b934881..06601e59a 100644 --- a/cosmos/profiles/__init__.py +++ b/cosmos/profiles/__init__.py @@ -13,6 +13,7 @@ from .databricks.oauth import DatabricksOauthProfileMapping from .databricks.token import DatabricksTokenProfileMapping from .exasol.user_pass import ExasolUserPasswordProfileMapping +from .oracle.user_pass import OracleUserPasswordProfileMapping from .postgres.user_pass import PostgresUserPasswordProfileMapping from .redshift.user_pass import RedshiftUserPasswordProfileMapping from .snowflake.user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping @@ -34,6 +35,7 @@ GoogleCloudOauthProfileMapping, DatabricksTokenProfileMapping, DatabricksOauthProfileMapping, + OracleUserPasswordProfileMapping, PostgresUserPasswordProfileMapping, RedshiftUserPasswordProfileMapping, SnowflakeUserPasswordProfileMapping, @@ -77,6 +79,7 @@ def get_automatic_profile_mapping( "DatabricksTokenProfileMapping", "DatabricksOauthProfileMapping", "DbtProfileConfigVars", + "OracleUserPasswordProfileMapping", "PostgresUserPasswordProfileMapping", "RedshiftUserPasswordProfileMapping", "SnowflakeUserPasswordProfileMapping", diff --git a/cosmos/profiles/oracle/__init__.py b/cosmos/profiles/oracle/__init__.py new file mode 100644 index 000000000..221f5f3a3 --- /dev/null +++ b/cosmos/profiles/oracle/__init__.py @@ -0,0 +1,5 @@ +"""Oracle Airflow connection -> dbt profile mappings""" + +from .user_pass import OracleUserPasswordProfileMapping + +__all__ = ["OracleUserPasswordProfileMapping"] diff --git a/cosmos/profiles/oracle/user_pass.py b/cosmos/profiles/oracle/user_pass.py new file mode 100644 index 000000000..f230848c5 --- /dev/null +++ b/cosmos/profiles/oracle/user_pass.py @@ -0,0 +1,89 @@ +"""Maps Airflow Oracle connections using user + password authentication to dbt profiles.""" + +from __future__ import annotations + +import re +from typing import Any + +from ..base import BaseProfileMapping + + +class OracleUserPasswordProfileMapping(BaseProfileMapping): + """ + Maps Airflow Oracle connections using user + password authentication to dbt profiles. + https://docs.getdbt.com/reference/warehouse-setups/oracle-setup + https://airflow.apache.org/docs/apache-airflow-providers-oracle/stable/connections/oracle.html + """ + + airflow_connection_type: str = "oracle" + dbt_profile_type: str = "oracle" + is_community: bool = True + + required_fields = [ + "user", + "password", + ] + secret_fields = [ + "password", + ] + airflow_param_mapping = { + "host": "host", + "port": "port", + "service": "extra.service_name", + "user": "login", + "password": "password", + "database": "extra.service_name", + "connection_string": "extra.dsn", + } + + @property + def env_vars(self) -> dict[str, str]: + """Set oracle thick mode.""" + env_vars = super().env_vars + if self._get_airflow_conn_field("extra.thick_mode"): + env_vars["ORA_PYTHON_DRIVER_TYPE"] = "thick" + return env_vars + + @property + def profile(self) -> dict[str, Any | None]: + """Gets profile. The password is stored in an environment variable.""" + profile = { + "protocol": "tcp", + "port": 1521, + **self.mapped_params, + **self.profile_args, + # password should always get set as env var + "password": self.get_env_var_format("password"), + } + + if "schema" not in profile and "user" in profile: + proxy = re.search(r"\[([^]]+)\]", profile["user"]) + if proxy: + profile["schema"] = proxy.group(1) + else: + profile["schema"] = profile["user"] + if "schema" in self.profile_args: + profile["schema"] = self.profile_args["schema"] + + return self.filter_null(profile) + + @property + def mock_profile(self) -> dict[str, Any | None]: + """Gets mock profile. Defaults port to 1521.""" + profile_dict = { + "protocol": "tcp", + "port": 1521, + **super().mock_profile, + } + + if "schema" not in profile_dict and "user" in profile_dict: + proxy = re.search(r"\[([^]]+)\]", profile_dict["user"]) + if proxy: + profile_dict["schema"] = proxy.group(1) + else: + profile_dict["schema"] = profile_dict["user"] + + user_defined_schema = self.profile_args.get("schema") + if user_defined_schema: + profile_dict["schema"] = user_defined_schema + return profile_dict diff --git a/pyproject.toml b/pyproject.toml index cad6c3896..ee5503510 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dbt-all = [ # See: https://github.com/astronomer/astronomer-cosmos/issues/1379 "dbt-databricks!=1.9.0", "dbt-exasol", + "dbt-oracle", "dbt-postgres", "dbt-redshift", "dbt-snowflake", @@ -61,6 +62,7 @@ dbt-bigquery = ["dbt-bigquery"] dbt-clickhouse = ["dbt-clickhouse"] dbt-databricks = ["dbt-databricks"] dbt-exasol = ["dbt-exasol"] +dbt-oracle = ["dbt-oracle"] dbt-postgres = ["dbt-postgres"] dbt-redshift = ["dbt-redshift"] dbt-snowflake = ["dbt-snowflake"] diff --git a/tests/profiles/oracle/test_oracle_user_pass.py b/tests/profiles/oracle/test_oracle_user_pass.py new file mode 100644 index 000000000..7f1258470 --- /dev/null +++ b/tests/profiles/oracle/test_oracle_user_pass.py @@ -0,0 +1,254 @@ +"""Tests for the Oracle profile.""" + +from unittest.mock import patch + +import pytest +from airflow.models.connection import Connection + +from cosmos.profiles import get_automatic_profile_mapping +from cosmos.profiles.oracle.user_pass import OracleUserPasswordProfileMapping + + +@pytest.fixture() +def mock_oracle_conn(): # type: ignore + """ + Sets the Oracle connection as an environment variable. + """ + conn = Connection( + conn_id="my_oracle_connection", + conn_type="oracle", + host="my_host", + login="my_user", + password="my_password", + port=1521, + extra='{"service_name": "my_service"}', + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + +@pytest.fixture() +def mock_oracle_conn_custom_port(): # type: ignore + """ + Sets the Oracle connection with a custom port as an environment variable. + """ + conn = Connection( + conn_id="my_oracle_connection", + conn_type="oracle", + host="my_host", + login="my_user", + password="my_password", + port=1600, + extra='{"service_name": "my_service"}', + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + +def test_connection_claiming() -> None: + """ + Tests that the Oracle profile mapping claims the correct connection type. + """ + potential_values = { + "conn_type": "oracle", + "login": "my_user", + "password": "my_password", + } + + # if we're missing any of the required values, it shouldn't claim + for key in potential_values: + values = potential_values.copy() + del values[key] + conn = Connection(**values) # type: ignore + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = OracleUserPasswordProfileMapping(conn, {"schema": "my_schema"}) + assert not profile_mapping.can_claim_connection() + + # if we have all the required values, it should claim + conn = Connection(**potential_values) # type: ignore + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = OracleUserPasswordProfileMapping(conn, {"schema": "my_schema"}) + assert profile_mapping.can_claim_connection() + + +def test_profile_mapping_selected( + mock_oracle_conn: Connection, +) -> None: + """ + Tests that the correct profile mapping is selected. + """ + profile_mapping = get_automatic_profile_mapping( + mock_oracle_conn.conn_id, + {"schema": "my_schema"}, + ) + assert isinstance(profile_mapping, OracleUserPasswordProfileMapping) + + +def test_profile_mapping_keeps_custom_port(mock_oracle_conn_custom_port: Connection) -> None: + profile = OracleUserPasswordProfileMapping(mock_oracle_conn_custom_port.conn_id, {"schema": "my_schema"}) + assert profile.profile["port"] == 1600 + + +def test_profile_args( + mock_oracle_conn: Connection, +) -> None: + """ + Tests that the profile values are set correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_oracle_conn.conn_id, + profile_args={"schema": "my_schema"}, + ) + assert profile_mapping.profile_args == { + "schema": "my_schema", + } + + assert profile_mapping.profile == { + "type": mock_oracle_conn.conn_type, + "host": mock_oracle_conn.host, + "user": mock_oracle_conn.login, + "password": "{{ env_var('COSMOS_CONN_ORACLE_PASSWORD') }}", + "port": mock_oracle_conn.port, + "database": "my_service", + "service": "my_service", + "schema": "my_schema", + "protocol": "tcp", + } + + +def test_profile_args_overrides( + mock_oracle_conn: Connection, +) -> None: + """ + Tests that profile values can be overridden. + """ + profile_mapping = get_automatic_profile_mapping( + mock_oracle_conn.conn_id, + profile_args={ + "schema": "my_schema_override", + "database": "my_database_override", + "service": "my_service_override", + }, + ) + assert profile_mapping.profile_args == { + "schema": "my_schema_override", + "database": "my_database_override", + "service": "my_service_override", + } + + assert profile_mapping.profile == { + "type": mock_oracle_conn.conn_type, + "host": mock_oracle_conn.host, + "user": mock_oracle_conn.login, + "password": "{{ env_var('COSMOS_CONN_ORACLE_PASSWORD') }}", + "port": mock_oracle_conn.port, + "database": "my_database_override", + "service": "my_service_override", + "schema": "my_schema_override", + "protocol": "tcp", + } + + +def test_profile_env_vars( + mock_oracle_conn: Connection, +) -> None: + """ + Tests that environment variables are set correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_oracle_conn.conn_id, + profile_args={"schema": "my_schema"}, + ) + assert profile_mapping.env_vars == { + "COSMOS_CONN_ORACLE_PASSWORD": mock_oracle_conn.password, + } + + +def test_env_vars_thick_mode(mock_oracle_conn: Connection) -> None: + """ + Tests that `env_vars` includes `ORA_PYTHON_DRIVER_TYPE` when `extra.thick_mode` is enabled. + """ + mock_oracle_conn.extra = '{"service_name": "my_service", "thick_mode": true}' + profile_mapping = OracleUserPasswordProfileMapping(mock_oracle_conn.conn_id, {"schema": "my_schema"}) + assert profile_mapping.env_vars == { + "COSMOS_CONN_ORACLE_PASSWORD": mock_oracle_conn.password, + "ORA_PYTHON_DRIVER_TYPE": "thick", + } + + +def test_profile_filter_null(mock_oracle_conn: Connection) -> None: + """ + Tests that `profile` filters out null values. + """ + mock_oracle_conn.extra = '{"service_name": "my_service"}' + profile_mapping = OracleUserPasswordProfileMapping(mock_oracle_conn.conn_id, {"schema": None}) + profile = profile_mapping.profile + assert "schema" not in profile + + +def test_mock_profile(mock_oracle_conn: Connection) -> None: + """ + Tests that `mock_profile` sets default port and schema correctly. + """ + profile_mapping = OracleUserPasswordProfileMapping(mock_oracle_conn.conn_id, {"schema": "my_schema"}) + mock_profile = profile_mapping.mock_profile + assert mock_profile["port"] == 1521 + assert mock_profile["schema"] == "my_schema" + assert mock_profile["protocol"] == "tcp" + + +def test_invalid_connection_type() -> None: + """ + Tests that the profile mapping does not claim a non-oracle connection type. + """ + conn = Connection(conn_id="invalid_conn", conn_type="postgres", login="my_user", password="my_password") + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = OracleUserPasswordProfileMapping(conn, {}) + assert not profile_mapping.can_claim_connection() + + +def test_airflow_param_mapping(mock_oracle_conn: Connection) -> None: + """ + Tests that `airflow_param_mapping` correctly maps Airflow fields to dbt profile fields. + """ + profile_mapping = OracleUserPasswordProfileMapping(mock_oracle_conn.conn_id, {"schema": "my_schema"}) + mapped_params = profile_mapping.mapped_params + + assert mapped_params["host"] == mock_oracle_conn.host + assert mapped_params["port"] == mock_oracle_conn.port + assert mapped_params["service"] == "my_service" + assert mapped_params["user"] == mock_oracle_conn.login + assert mapped_params["password"] == mock_oracle_conn.password + + +def test_profile_schema_extraction_with_proxy(mock_oracle_conn: Connection) -> None: + """ + Tests that the `schema` is extracted correctly from the `user` field + when a proxy schema is provided in square brackets. + """ + mock_oracle_conn.login = "my_user[proxy_schema]" + profile_mapping = OracleUserPasswordProfileMapping(mock_oracle_conn.conn_id, {}) + + assert profile_mapping.profile["schema"] == "proxy_schema" + + +def test_profile_schema_defaults_to_user(mock_oracle_conn: Connection) -> None: + """ + Tests that the `schema` defaults to the `user` field when no proxy schema is provided. + """ + mock_oracle_conn.login = "my_user" + profile_mapping = OracleUserPasswordProfileMapping(mock_oracle_conn.conn_id, {}) + + assert profile_mapping.profile["schema"] == "my_user" + + +def test_mock_profile_schema_extraction_with_proxy_gets_mock_value(mock_oracle_conn: Connection) -> None: + mock_oracle_conn.login = "my_user[proxy_schema]" + profile_mapping = OracleUserPasswordProfileMapping(mock_oracle_conn.conn_id, {}) + + mock_profile = profile_mapping.mock_profile + + assert mock_profile["schema"] == "mock_value"