Skip to content

Commit

Permalink
Add tests to sources, snapshots and seeds when using TestBehavior.AFT…
Browse files Browse the repository at this point in the history
…ER_EACH (#599)

Previously Cosmos would only create a task group when using
`TestBehavior.AFTER_EACH` for nodes of the type `DbtResourceType.MODEL`.
This change adds the same behavior to snapshots and seeds.

For this to work as expected with sources, we would need to create a
default operator to handle `DbtResourceType.SOURCE`, which is outside
the scope of the current ticket. Once this operator exists, sources will
also lead to creating a task group.

All the test selectors were tested successfully with dbt 1.6.

This screenshot illustrates the validation of this feature, with an
adapted version of Jaffle Shop:
<img width="1502" alt="Screenshot 2023-10-13 at 19 43 52"
src="https://github.com/astronomer/astronomer-cosmos/assets/272048/893ac557-1846-4ed1-b5e2-913a7bd38485">

The modifications that were done to jaffle_shop were:

Appended the following lines to
`dev/dags/dbt/jaffle_shop/models/schema.yml`:
```
seeds:
  - name: raw_customers
    description: Raw data from customers
    columns:
      - name: id
        tests:
          - unique
          - not_null

snapshots:
  - name: orders_snapshot
    description: Snapshot of orders
    columns:
      - name: orders_snapshot.order_id
        tests:
          - unique
          - not_null
```

And created the file
`dev/dags/dbt/jaffle_shop/snapshots/orders_snapshot.sql` with:
```
{% snapshot orders_snapshot %}

{{
    config(
      target_database='postgres',
      target_schema='public',
      unique_key='order_id',

      strategy='timestamp',
      updated_at='order_date',
    )
}}

select * from {{ ref('jaffle_shop', 'orders') }}

{% endsnapshot %}
```

Closes: #474
  • Loading branch information
tatiana authored Oct 18, 2023
1 parent a433f15 commit e09ac6d
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 13 deletions.
17 changes: 12 additions & 5 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def create_test_task_metadata(
execution_mode: ExecutionMode,
task_args: dict[str, Any],
on_warning_callback: Callable[..., Any] | None = None,
model_name: str | None = None,
node: DbtNode | None = None,
) -> TaskMetadata:
"""
Create the metadata that will be used to instantiate the Airflow Task that will be used to run the Dbt test node.
Expand All @@ -66,13 +66,18 @@ def create_test_task_metadata(
:param task_args: Arguments to be used to instantiate an Airflow Task
:param on_warning_callback: A callback function called on warnings with additional Context variables “test_names”
and “test_results” of type List.
:param model_name: If the test relates to a specific model, the name of the model it relates to
:param node: If the test relates to a specific node, the node reference
:returns: The metadata necessary to instantiate the source dbt node as an Airflow task.
"""
task_args = dict(task_args)
task_args["on_warning_callback"] = on_warning_callback
if model_name is not None:
task_args["models"] = model_name
if node is not None:
if node.resource_type == DbtResourceType.MODEL:
task_args["models"] = node.name
elif node.resource_type == DbtResourceType.SOURCE:
task_args["select"] = f"source:{node.unique_id[len('source.'):]}"
else: # tested with node.resource_type == DbtResourceType.SEED or DbtResourceType.SNAPSHOT
task_args["select"] = node.name
return TaskMetadata(
id=test_task_name,
operator_class=calculate_operator_class(
Expand Down Expand Up @@ -112,6 +117,8 @@ def create_task_metadata(
task_id = "run"
else:
task_id = f"{node.name}_{node.resource_type.value}"
if use_task_group is True:
task_id = node.resource_type.value

task_metadata = TaskMetadata(
id=task_id,
Expand Down Expand Up @@ -163,7 +170,7 @@ def generate_task_or_group(
"test",
execution_mode,
task_args=task_args,
model_name=node.name,
node=node,
on_warning_callback=on_warning_callback,
)
test_task = create_airflow_task(test_meta, dag, task_group=model_task_group)
Expand Down
8 changes: 4 additions & 4 deletions cosmos/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _missing_value_(cls, value): # type: ignore

DEFAULT_DBT_RESOURCES = DbtResourceType.__members__.values()


TESTABLE_DBT_RESOURCES = {
DbtResourceType.MODEL
} # TODO: extend with DbtResourceType.SOURCE, DbtResourceType.SNAPSHOT, DbtResourceType.SEED)
# dbt test runs tests defined on models, sources, snapshots, and seeds.
# It expects that you have already created those resources through the appropriate commands.
# https://docs.getdbt.com/reference/commands/test
TESTABLE_DBT_RESOURCES = {DbtResourceType.MODEL, DbtResourceType.SOURCE, DbtResourceType.SNAPSHOT, DbtResourceType.SEED}
77 changes: 73 additions & 4 deletions tests/airflow/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from airflow import __version__ as airflow_version
from airflow.models import DAG
from airflow.utils.task_group import TaskGroup
from packaging import version

from cosmos.airflow.graph import (
Expand All @@ -13,6 +14,7 @@
calculate_operator_class,
create_task_metadata,
create_test_task_metadata,
generate_task_or_group,
)
from cosmos.config import ProfileConfig
from cosmos.constants import DbtResourceType, ExecutionMode, TestBehavior
Expand Down Expand Up @@ -101,6 +103,50 @@ def test_build_airflow_graph_with_after_each():
assert dag.leaves[0].task_id == "child_run"


@pytest.mark.parametrize(
"node_type,task_suffix",
[(DbtResourceType.MODEL, "run"), (DbtResourceType.SEED, "seed"), (DbtResourceType.SNAPSHOT, "snapshot")],
)
def test_create_task_group_for_after_each_supported_nodes(node_type, task_suffix):
"""
dbt test runs tests defined on models, sources, snapshots, and seeds.
It expects that you have already created those resources through the appropriate commands.
https://docs.getdbt.com/reference/commands/test
"""
with DAG("test-task-group-after-each", start_date=datetime(2022, 1, 1)) as dag:
node = DbtNode(
name="dbt_node",
unique_id="dbt_node",
resource_type=node_type,
file_path=SAMPLE_PROJ_PATH / "gen2/models/parent.sql",
tags=["has_child"],
config={"materialized": "view"},
depends_on=[],
has_test=True,
)
output = generate_task_or_group(
dag=dag,
task_group=None,
node=node,
execution_mode=ExecutionMode.LOCAL,
task_args={
"project_dir": SAMPLE_PROJ_PATH,
"profile_config": ProfileConfig(
profile_name="default",
target_name="default",
profile_mapping=PostgresUserPasswordProfileMapping(
conn_id="fake_conn",
profile_args={"schema": "public"},
),
),
},
test_behavior=TestBehavior.AFTER_EACH,
on_warning_callback=None,
)
assert isinstance(output, TaskGroup)
assert list(output.children.keys()) == [f"dbt_node.{task_suffix}", "dbt_node.test"]


@pytest.mark.skipif(
version.parse(airflow_version) < version.parse("2.4"),
reason="Airflow DAG did not have task_group_dict until the 2.4 release",
Expand Down Expand Up @@ -259,7 +305,12 @@ def test_create_task_metadata_seed(caplog, use_task_group):
args={},
use_task_group=use_task_group,
)
assert metadata.id == "my_seed_seed"

if not use_task_group:
assert metadata.id == "my_seed_seed"
else:
assert metadata.id == "seed"

assert metadata.operator_class == "cosmos.operators.docker.DbtSeedDockerOperator"
assert metadata.arguments == {"models": "my_seed"}

Expand All @@ -280,14 +331,32 @@ def test_create_task_metadata_snapshot(caplog):
assert metadata.arguments == {"models": "my_snapshot"}


def test_create_test_task_metadata():
@pytest.mark.parametrize(
"node_type,node_unique_id,selector_key,selector_value",
[
(DbtResourceType.MODEL, "node_name", "models", "node_name"),
(DbtResourceType.SEED, "node_name", "select", "node_name"),
(DbtResourceType.SOURCE, "source.node_name", "select", "source:node_name"),
(DbtResourceType.SNAPSHOT, "node_name", "select", "node_name"),
],
)
def test_create_test_task_metadata(node_type, node_unique_id, selector_key, selector_value):
sample_node = DbtNode(
name="node_name",
unique_id=node_unique_id,
resource_type=node_type,
depends_on=[],
file_path="",
tags=[],
config={},
)
metadata = create_test_task_metadata(
test_task_name="test_no_nulls",
execution_mode=ExecutionMode.LOCAL,
task_args={"task_arg": "value"},
on_warning_callback=True,
model_name="my_model",
node=sample_node,
)
assert metadata.id == "test_no_nulls"
assert metadata.operator_class == "cosmos.operators.local.DbtTestLocalOperator"
assert metadata.arguments == {"task_arg": "value", "on_warning_callback": True, "models": "my_model"}
assert metadata.arguments == {"task_arg": "value", "on_warning_callback": True, selector_key: selector_value}

0 comments on commit e09ac6d

Please sign in to comment.