diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index c4a0aa61a..3e7961ed1 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -56,7 +56,7 @@ def create_test_task_metadata( execution_mode: ExecutionMode, task_args: dict[str, Any], on_warning_callback: Callable[..., Any] | None = None, - model_name: str | None = None, + node: DbtNode | None = None, ) -> TaskMetadata: """ Create the metadata that will be used to instantiate the Airflow Task that will be used to run the Dbt test node. @@ -66,13 +66,18 @@ def create_test_task_metadata( :param task_args: Arguments to be used to instantiate an Airflow Task :param on_warning_callback: A callback function called on warnings with additional Context variables “test_names” and “test_results” of type List. - :param model_name: If the test relates to a specific model, the name of the model it relates to + :param node: If the test relates to a specific node, the node reference :returns: The metadata necessary to instantiate the source dbt node as an Airflow task. """ task_args = dict(task_args) task_args["on_warning_callback"] = on_warning_callback - if model_name is not None: - task_args["models"] = model_name + if node is not None: + if node.resource_type == DbtResourceType.MODEL: + task_args["models"] = node.name + elif node.resource_type == DbtResourceType.SOURCE: + task_args["select"] = f"source:{node.unique_id[len('source.'):]}" + else: # tested with node.resource_type == DbtResourceType.SEED or DbtResourceType.SNAPSHOT + task_args["select"] = node.name return TaskMetadata( id=test_task_name, operator_class=calculate_operator_class( @@ -112,6 +117,8 @@ def create_task_metadata( task_id = "run" else: task_id = f"{node.name}_{node.resource_type.value}" + if use_task_group is True: + task_id = node.resource_type.value task_metadata = TaskMetadata( id=task_id, @@ -163,7 +170,7 @@ def generate_task_or_group( "test", execution_mode, task_args=task_args, - model_name=node.name, + node=node, on_warning_callback=on_warning_callback, ) test_task = create_airflow_task(test_meta, dag, task_group=model_task_group) diff --git a/cosmos/constants.py b/cosmos/constants.py index 2ce9a09fe..cd59c8173 100644 --- a/cosmos/constants.py +++ b/cosmos/constants.py @@ -70,7 +70,7 @@ def _missing_value_(cls, value): # type: ignore DEFAULT_DBT_RESOURCES = DbtResourceType.__members__.values() - -TESTABLE_DBT_RESOURCES = { - DbtResourceType.MODEL -} # TODO: extend with DbtResourceType.SOURCE, DbtResourceType.SNAPSHOT, DbtResourceType.SEED) +# dbt test runs tests defined on models, sources, snapshots, and seeds. +# It expects that you have already created those resources through the appropriate commands. +# https://docs.getdbt.com/reference/commands/test +TESTABLE_DBT_RESOURCES = {DbtResourceType.MODEL, DbtResourceType.SOURCE, DbtResourceType.SNAPSHOT, DbtResourceType.SEED} diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index 57a3462f7..2eb93c613 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -5,6 +5,7 @@ import pytest from airflow import __version__ as airflow_version from airflow.models import DAG +from airflow.utils.task_group import TaskGroup from packaging import version from cosmos.airflow.graph import ( @@ -13,6 +14,7 @@ calculate_operator_class, create_task_metadata, create_test_task_metadata, + generate_task_or_group, ) from cosmos.config import ProfileConfig from cosmos.constants import DbtResourceType, ExecutionMode, TestBehavior @@ -101,6 +103,50 @@ def test_build_airflow_graph_with_after_each(): assert dag.leaves[0].task_id == "child_run" +@pytest.mark.parametrize( + "node_type,task_suffix", + [(DbtResourceType.MODEL, "run"), (DbtResourceType.SEED, "seed"), (DbtResourceType.SNAPSHOT, "snapshot")], +) +def test_create_task_group_for_after_each_supported_nodes(node_type, task_suffix): + """ + dbt test runs tests defined on models, sources, snapshots, and seeds. + It expects that you have already created those resources through the appropriate commands. + https://docs.getdbt.com/reference/commands/test + """ + with DAG("test-task-group-after-each", start_date=datetime(2022, 1, 1)) as dag: + node = DbtNode( + name="dbt_node", + unique_id="dbt_node", + resource_type=node_type, + file_path=SAMPLE_PROJ_PATH / "gen2/models/parent.sql", + tags=["has_child"], + config={"materialized": "view"}, + depends_on=[], + has_test=True, + ) + output = generate_task_or_group( + dag=dag, + task_group=None, + node=node, + execution_mode=ExecutionMode.LOCAL, + task_args={ + "project_dir": SAMPLE_PROJ_PATH, + "profile_config": ProfileConfig( + profile_name="default", + target_name="default", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="fake_conn", + profile_args={"schema": "public"}, + ), + ), + }, + test_behavior=TestBehavior.AFTER_EACH, + on_warning_callback=None, + ) + assert isinstance(output, TaskGroup) + assert list(output.children.keys()) == [f"dbt_node.{task_suffix}", "dbt_node.test"] + + @pytest.mark.skipif( version.parse(airflow_version) < version.parse("2.4"), reason="Airflow DAG did not have task_group_dict until the 2.4 release", @@ -259,7 +305,12 @@ def test_create_task_metadata_seed(caplog, use_task_group): args={}, use_task_group=use_task_group, ) - assert metadata.id == "my_seed_seed" + + if not use_task_group: + assert metadata.id == "my_seed_seed" + else: + assert metadata.id == "seed" + assert metadata.operator_class == "cosmos.operators.docker.DbtSeedDockerOperator" assert metadata.arguments == {"models": "my_seed"} @@ -280,14 +331,32 @@ def test_create_task_metadata_snapshot(caplog): assert metadata.arguments == {"models": "my_snapshot"} -def test_create_test_task_metadata(): +@pytest.mark.parametrize( + "node_type,node_unique_id,selector_key,selector_value", + [ + (DbtResourceType.MODEL, "node_name", "models", "node_name"), + (DbtResourceType.SEED, "node_name", "select", "node_name"), + (DbtResourceType.SOURCE, "source.node_name", "select", "source:node_name"), + (DbtResourceType.SNAPSHOT, "node_name", "select", "node_name"), + ], +) +def test_create_test_task_metadata(node_type, node_unique_id, selector_key, selector_value): + sample_node = DbtNode( + name="node_name", + unique_id=node_unique_id, + resource_type=node_type, + depends_on=[], + file_path="", + tags=[], + config={}, + ) metadata = create_test_task_metadata( test_task_name="test_no_nulls", execution_mode=ExecutionMode.LOCAL, task_args={"task_arg": "value"}, on_warning_callback=True, - model_name="my_model", + node=sample_node, ) assert metadata.id == "test_no_nulls" assert metadata.operator_class == "cosmos.operators.local.DbtTestLocalOperator" - assert metadata.arguments == {"task_arg": "value", "on_warning_callback": True, "models": "my_model"} + assert metadata.arguments == {"task_arg": "value", "on_warning_callback": True, selector_key: selector_value}