Save tasks map as DbtToAirflowConverter property (#1362)
Currently, if you want to modify a DAG after it has been rendered,
you have to walk through the dag.dbt_graph and then piece the task IDs
and task group IDs together by reverse-engineering your task rendering
strategy.

This is cumbersome and error-prone, so it makes sense to expose the
mapping from DbtNode unique ID to Airflow task as a DAG property. This
lets you walk the dbt graph while directly accessing any corresponding
Airflow tasks, which makes, for example, adding Airflow sensors
upstream of all source tasks much easier.
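
A quick before/after sketch of the lookup this enables (the rendered task ID and
the dbt unique_id below are hypothetical, for illustration only):

# Before: reconstruct the rendered task ID by hand (fragile, rendering-strategy-specific).
task = dag.get_task("raw_customers_source")  # hypothetical rendered task ID

# After: look the task up directly via the exposed mapping.
task = dag.tasks_map["source.jaffle_shop.raw_customers"]  # hypothetical dbt unique_id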

Co-authored-by: hheemskerk <[email protected]>
internetcoffeephone and hheemskerk authored Dec 18, 2024
1 parent 82d476b commit 6d4a239
Showing 6 changed files with 131 additions and 3 deletions.
6 changes: 4 additions & 2 deletions cosmos/airflow/graph.py
@@ -336,7 +336,7 @@ def build_airflow_graph(
     render_config: RenderConfig,
     task_group: TaskGroup | None = None,
     on_warning_callback: Callable[..., Any] | None = None,  # argument specific to the DBT test command
-) -> None:
+) -> dict[str, Union[TaskGroup, BaseOperator]]:
     """
     Instantiate dbt `nodes` as Airflow tasks within the given `task_group` (optional) or `dag` (mandatory).
@@ -359,11 +359,12 @@ def build_airflow_graph(
     :param task_group: Airflow Task Group instance
     :param on_warning_callback: A callback function called on warnings with additional Context variables "test_names"
         and "test_results" of type List.
+    :return: Dictionary mapping dbt nodes (node.unique_id) to Airflow tasks
     """
     node_converters = render_config.node_converters or {}
     test_behavior = render_config.test_behavior
     source_rendering_behavior = render_config.source_rendering_behavior
-    tasks_map = {}
+    tasks_map: dict[str, Union[TaskGroup, BaseOperator]] = {}
     task_or_group: TaskGroup | BaseOperator

     for node_id, node in nodes.items():
@@ -408,6 +409,7 @@ def build_airflow_graph(

     create_airflow_task_dependencies(nodes, tasks_map)
     _add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map, task_group)
+    return tasks_map


 def create_airflow_task_dependencies(
2 changes: 1 addition & 1 deletion cosmos/converter.py
@@ -292,7 +292,7 @@ def __init__(
         if execution_config.execution_mode == ExecutionMode.VIRTUALENV and execution_config.virtualenv_dir is not None:
             task_args["virtualenv_dir"] = execution_config.virtualenv_dir

-        build_airflow_graph(
+        self.tasks_map = build_airflow_graph(
             nodes=self.dbt_graph.filtered_nodes,
             dag=dag or (task_group and task_group.dag),
             task_group=task_group,
55 changes: 55 additions & 0 deletions dev/dags/example_tasks_map.py
@@ -0,0 +1,55 @@
"""
An example DAG that demonstrates how to walk over the dbt graph. It also shows how to use the mapping from
{dbt graph unique_id} -> {Airflow tasks/task groups}.
"""

import os
from datetime import datetime
from pathlib import Path

from airflow.operators.empty import EmptyOperator

from cosmos import DbtDag, DbtResourceType, ProfileConfig, ProjectConfig
from cosmos.profiles import PostgresUserPasswordProfileMapping

DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))

profile_config = ProfileConfig(
profile_name="default",
target_name="dev",
profile_mapping=PostgresUserPasswordProfileMapping(
conn_id="example_conn",
profile_args={"schema": "public"},
disable_event_tracking=True,
),
)

# [START example_tasks_map]
with DbtDag(
# dbt/cosmos-specific parameters
project_config=ProjectConfig(
DBT_ROOT_PATH / "jaffle_shop",
),
profile_config=profile_config,
operator_args={
"install_deps": True, # install any necessary dependencies before running any dbt command
"full_refresh": True, # used only in dbt commands that support this flag
},
# normal dag parameters
schedule_interval="@daily",
start_date=datetime(2023, 1, 1),
catchup=False,
dag_id="customized_cosmos_dag",
default_args={"retries": 2},
) as dag:
# Walk the dbt graph
for unique_id, dbt_node in dag.dbt_graph.filtered_nodes.items():
# Filter by any dbt_node property you prefer. In this case, we are adding upstream tasks to source nodes.
if dbt_node.resource_type == DbtResourceType.SOURCE:
# Look up the corresponding Airflow task or task group in the DbtToAirflowConverter.tasks_map property.
task = dag.tasks_map[unique_id]
# Create a task upstream of this Airflow source task/task group.
upstream_task = EmptyOperator(task_id=f"upstream_of_{unique_id}")
upstream_task >> task
# [END example_tasks_map]
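
A note on the mapping's value type: build_airflow_graph returns
dict[str, Union[TaskGroup, BaseOperator]], so an entry in tasks_map may be a single
operator or a whole task group (for example, a node rendered together with its tests).
A minimal sketch of telling the two apart, assuming the dag object from the example above:

from airflow.models.baseoperator import BaseOperator
from airflow.utils.task_group import TaskGroup

for unique_id, task_or_group in dag.tasks_map.items():
    if isinstance(task_or_group, TaskGroup):
        # Nodes rendered together with their tests become task groups.
        print(f"{unique_id} -> task group {task_or_group.group_id}")
    elif isinstance(task_or_group, BaseOperator):
        print(f"{unique_id} -> task {task_or_group.task_id}")

Dependency operators (>> and <<) accept both types, so downstream wiring usually does
not need this distinction.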
38 changes: 38 additions & 0 deletions docs/configuration/dag-customization.rst
@@ -0,0 +1,38 @@
.. _dag_customization:

Post-rendering DAG customization
================================

.. note::
    The DbtToAirflowConverter.tasks_map property is only available for cosmos >= 1.8.0.

After Cosmos has rendered an Airflow DAG from a dbt project, you may want to add some extra Airflow tasks that interact
with the tasks created by Cosmos. This document explains how to do this.

One example use case is implementing sensor tasks that wait for a task in an external DAG to complete
before running a source node task (or task group, if the source contains a test); a sensor sketch
follows the example below.

Mapping from dbt nodes to Airflow tasks
---------------------------------------

To interact with Airflow tasks created by Cosmos,
you can iterate over the dag.dbt_graph.filtered_nodes property like so:

..
    This is an abbreviated copy of example_tasks_map.py, as GitHub does not render literalinclude blocks

.. code-block:: python

    with DbtDag(
        dag_id="customized_cosmos_dag",
        # Other arguments omitted for brevity
    ) as dag:
        # Walk the dbt graph
        for unique_id, dbt_node in dag.dbt_graph.filtered_nodes.items():
            # Filter by any dbt_node property you prefer. In this case, we are adding upstream tasks to source nodes.
            if dbt_node.resource_type == DbtResourceType.SOURCE:
                # Look up the corresponding Airflow task or task group in the DbtToAirflowConverter.tasks_map property.
                task = dag.tasks_map[unique_id]
                # Create a task upstream of this Airflow source task/task group.
                upstream_task = EmptyOperator(task_id=f"upstream_of_{unique_id}")
                upstream_task >> task
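
To implement the sensor use case mentioned above, the EmptyOperator in the loop can be swapped
for a sensor. A sketch, assuming hypothetical external DAG and task IDs:

.. code-block:: python

    from airflow.sensors.external_task import ExternalTaskSensor

    # Inside the same loop as above: wait for an ingestion task in another DAG
    # before the Cosmos-rendered source task/task group runs.
    upstream_task = ExternalTaskSensor(
        task_id=f"wait_for_{unique_id}",
        external_dag_id="ingestion_dag",  # hypothetical upstream DAG ID
        external_task_id="load_raw_data",  # hypothetical upstream task ID
    )
    upstream_task >> task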
1 change: 1 addition & 0 deletions docs/configuration/index.rst
@@ -23,6 +23,7 @@ Cosmos offers a number of configuration options to customize its behavior. For m
     Selecting & Excluding <selecting-excluding>
     Partial Parsing <partial-parsing>
     Source Nodes Rendering <source-nodes-rendering>
+    Post-rendering DAG customization <dag-customization>
     Operator Args <operator-args>
     Compiled SQL <compiled-sql>
     Logging <logging>
32 changes: 32 additions & 0 deletions tests/test_converter.py
@@ -578,3 +578,35 @@ def test_converter_contains_dbt_graph(mock_load_dbt_graph, execution_mode, operator_args):
         operator_args=operator_args,
     )
     assert isinstance(converter.dbt_graph, DbtGraph)
+
+
+@pytest.mark.parametrize(
+    "execution_mode,operator_args",
+    [
+        (ExecutionMode.KUBERNETES, {}),
+    ],
+)
+@patch("cosmos.converter.DbtGraph.filtered_nodes", nodes)
+@patch("cosmos.converter.DbtGraph.load")
+def test_converter_contains_tasks_map(mock_load_dbt_graph, execution_mode, operator_args):
+    """
+    This test validates that DbtToAirflowConverter contains and exposes a tasks map instance
+    """
+    project_config = ProjectConfig(dbt_project_path=SAMPLE_DBT_PROJECT)
+    execution_config = ExecutionConfig(execution_mode=execution_mode)
+    render_config = RenderConfig(emit_datasets=True)
+    profile_config = ProfileConfig(
+        profile_name="my_profile_name",
+        target_name="my_target_name",
+        profiles_yml_filepath=SAMPLE_PROFILE_YML,
+    )
+    converter = DbtToAirflowConverter(
+        dag=DAG("sample_dag", start_date=datetime(2024, 1, 1)),
+        nodes=nodes,
+        project_config=project_config,
+        profile_config=profile_config,
+        execution_config=execution_config,
+        render_config=render_config,
+        operator_args=operator_args,
+    )
+    assert isinstance(converter.tasks_map, dict)
