diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 3e3103266..4848c45c5 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -336,7 +336,7 @@ def build_airflow_graph( render_config: RenderConfig, task_group: TaskGroup | None = None, on_warning_callback: Callable[..., Any] | None = None, # argument specific to the DBT test command -) -> None: +) -> dict[str, Union[TaskGroup, BaseOperator]]: """ Instantiate dbt `nodes` as Airflow tasks within the given `task_group` (optional) or `dag` (mandatory). @@ -359,11 +359,12 @@ def build_airflow_graph( :param task_group: Airflow Task Group instance :param on_warning_callback: A callback function called on warnings with additional Context variables “test_names” and “test_results” of type List. + :return: Dictionary mapping each dbt node's unique_id to its corresponding Airflow task or task group """ node_converters = render_config.node_converters or {} test_behavior = render_config.test_behavior source_rendering_behavior = render_config.source_rendering_behavior - tasks_map = {} + tasks_map: dict[str, Union[TaskGroup, BaseOperator]] = {} task_or_group: TaskGroup | BaseOperator for node_id, node in nodes.items(): @@ -408,6 +409,7 @@ create_airflow_task_dependencies(nodes, tasks_map) _add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map, task_group) + return tasks_map def create_airflow_task_dependencies( diff --git a/cosmos/converter.py b/cosmos/converter.py index debb5c0bb..5bf99cac8 100644 --- a/cosmos/converter.py +++ b/cosmos/converter.py @@ -292,7 +292,7 @@ def __init__( if execution_config.execution_mode == ExecutionMode.VIRTUALENV and execution_config.virtualenv_dir is not None: task_args["virtualenv_dir"] = execution_config.virtualenv_dir - build_airflow_graph( + self.tasks_map = build_airflow_graph( nodes=self.dbt_graph.filtered_nodes, dag=dag or (task_group and task_group.dag), task_group=task_group, diff --git a/dev/dags/example_tasks_map.py
b/dev/dags/example_tasks_map.py new file mode 100644 index 000000000..8147ad740 --- /dev/null +++ b/dev/dags/example_tasks_map.py @@ -0,0 +1,55 @@ +""" +An example DAG that demonstrates how to walk over the dbt graph. It also shows how to use the mapping from +{dbt graph unique_id} -> {Airflow tasks/task groups}. +""" + +import os +from datetime import datetime +from pathlib import Path + +from airflow.operators.empty import EmptyOperator + +from cosmos import DbtDag, DbtResourceType, ProfileConfig, ProjectConfig +from cosmos.profiles import PostgresUserPasswordProfileMapping + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), +) + +# [START example_tasks_map] +with DbtDag( + # dbt/cosmos-specific parameters + project_config=ProjectConfig( + DBT_ROOT_PATH / "jaffle_shop", + ), + profile_config=profile_config, + operator_args={ + "install_deps": True, # install any necessary dependencies before running any dbt command + "full_refresh": True, # used only in dbt commands that support this flag + }, + # normal dag parameters + schedule_interval="@daily", + start_date=datetime(2023, 1, 1), + catchup=False, + dag_id="customized_cosmos_dag", + default_args={"retries": 2}, +) as dag: + # Walk the dbt graph + for unique_id, dbt_node in dag.dbt_graph.filtered_nodes.items(): + # Filter by any dbt_node property you prefer. In this case, we are adding upstream tasks to source nodes. + if dbt_node.resource_type == DbtResourceType.SOURCE: + # Look up the corresponding Airflow task or task group in the DbtToAirflowConverter.tasks_map property. + task = dag.tasks_map[unique_id] + # Create a task upstream of this Airflow source task/task group. 
+ upstream_task = EmptyOperator(task_id=f"upstream_of_{unique_id}") + upstream_task >> task +# [END example_tasks_map] diff --git a/docs/configuration/dag-customization.rst b/docs/configuration/dag-customization.rst new file mode 100644 index 000000000..2936f9495 --- /dev/null +++ b/docs/configuration/dag-customization.rst @@ -0,0 +1,38 @@ +.. _dag_customization: + +Post-rendering DAG customization +================================ + +.. note:: + The DbtToAirflowConverter.tasks_map property is only available for cosmos >= 1.8.0 + +After Cosmos has rendered an Airflow DAG from a dbt project, you may want to add some extra Airflow tasks that interact +with the tasks created by Cosmos. This document explains how to do this. + +An example use case you can think of is implementing sensor tasks that wait for an external DAG task to complete before +running a source node task (or task group, if the source contains a test). + +Mapping from dbt nodes to Airflow tasks +--------------------------------------- + +To interact with Airflow tasks created by Cosmos, +you can iterate over the dag.dbt_graph.filtered_nodes property like so: + +.. + This is an abbreviated copy of example_tasks_map.py, as GitHub does not render literalinclude blocks + +.. code-block:: python + + with DbtDag( + dag_id="customized_cosmos_dag", + # Other arguments omitted for brevity + ) as dag: + # Walk the dbt graph + for unique_id, dbt_node in dag.dbt_graph.filtered_nodes.items(): + # Filter by any dbt_node property you prefer. In this case, we are adding upstream tasks to source nodes. + if dbt_node.resource_type == DbtResourceType.SOURCE: + # Look up the corresponding Airflow task or task group in the DbtToAirflowConverter.tasks_map property. + task = dag.tasks_map[unique_id] + # Create a task upstream of this Airflow source task/task group.
+ upstream_task = EmptyOperator(task_id=f"upstream_of_{unique_id}") + upstream_task >> task diff --git a/docs/configuration/index.rst b/docs/configuration/index.rst index 6c47884e9..9001b4c2e 100644 --- a/docs/configuration/index.rst +++ b/docs/configuration/index.rst @@ -23,6 +23,7 @@ Cosmos offers a number of configuration options to customize its behavior. For m Selecting & Excluding Partial Parsing Source Nodes Rendering + Post-rendering DAG customization Operator Args Compiled SQL Logging diff --git a/tests/test_converter.py b/tests/test_converter.py index 9a8563212..9da31f00d 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -578,3 +578,35 @@ def test_converter_contains_dbt_graph(mock_load_dbt_graph, execution_mode, opera operator_args=operator_args, ) assert isinstance(converter.dbt_graph, DbtGraph) + + +@pytest.mark.parametrize( + "execution_mode,operator_args", + [ + (ExecutionMode.KUBERNETES, {}), + ], +) +@patch("cosmos.converter.DbtGraph.filtered_nodes", nodes) +@patch("cosmos.converter.DbtGraph.load") +def test_converter_contains_tasks_map(mock_load_dbt_graph, execution_mode, operator_args): + """ + This test validates that DbtToAirflowConverter contains and exposes a tasks map instance + """ + project_config = ProjectConfig(dbt_project_path=SAMPLE_DBT_PROJECT) + execution_config = ExecutionConfig(execution_mode=execution_mode) + render_config = RenderConfig(emit_datasets=True) + profile_config = ProfileConfig( + profile_name="my_profile_name", + target_name="my_target_name", + profiles_yml_filepath=SAMPLE_PROFILE_YML, + ) + converter = DbtToAirflowConverter( + dag=DAG("sample_dag", start_date=datetime(2024, 1, 1)), + nodes=nodes, + project_config=project_config, + profile_config=profile_config, + execution_config=execution_config, + render_config=render_config, + operator_args=operator_args, + ) + assert isinstance(converter.tasks_map, dict)