# Experimental BQ support to run dbt models with `ExecutionMode.AIRFLOW_ASYNC` (#1230)

Enable BQ users to run dbt models (`full_refresh`) asynchronously. This
releases the Airflow worker node from waiting while the transformation
(I/O) happens in the data warehouse, increasing the overall Airflow task
throughput (more information:
https://airflow.apache.org/docs/apache-airflow/stable/authoring-and-scheduling/deferring.html).
As part of this change, we introduce the capability of running the
actual SQL transformations without using the dbt command. This also avoids creating
subprocesses in the worker node (`ExecutionMode.LOCAL` with
`InvocationMode.SUBPROCESS` and `ExecutionMode.VIRTUALENV`) or the
overhead of creating a Kubernetes Pod to execute the actual dbt command
(`ExecutionMode.KUBERNETES`). This can avoid issues related to memory
and CPU usage.

This PR takes advantage of an already implemented async operator in the
Airflow repo by extending it in the Cosmos async operator. It also
utilizes the pre-compiled SQL generated as part of PR
#1224. It downloads
the generated SQL from a remote location (S3/GCS), which allows us to
decouple from dbt during task execution.
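
As a rough sketch, the pre-compiled SQL for a model can be fetched at task runtime with Airflow's Object Storage API (Airflow 2.8+). The path layout below mirrors the `get_remote_sql` implementation in this PR; the bucket, identifiers, and connection id are placeholders:

```python
from airflow.io.path import ObjectStoragePath  # requires Airflow 2.8+


def fetch_compiled_sql(
    remote_target_path: str,
    dbt_dag_task_group_identifier: str,
    relative_file_path: str,
    conn_id: str,
) -> str:
    """Read a model's pre-compiled SQL from remote storage (S3/GCS)."""
    remote_model_path = (
        f"{remote_target_path.rstrip('/')}/{dbt_dag_task_group_identifier}/compiled/{relative_file_path}"
    )
    object_storage_path = ObjectStoragePath(remote_model_path, conn_id=conn_id)
    with object_storage_path.open() as fp:
        return fp.read()


# Placeholder usage:
# sql = fetch_compiled_sql("gs://my-bucket/target", "simple_dag_async", "models/my_model.sql", "gcp_conn")
```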

## Details

- Expose `get_profile_type` on `ProfileConfig`: this aids in database
selection (see the sketch after this list)
- ~Add `async_op_args`: A high-level parameter to forward arguments to
the upstream operator (Airflow operator). (This may change in this PR
itself)~ The async operator params are processed as kwargs in the
`operator_args` parameter
- Implement `DbtRunAirflowAsyncOperator`: This initializes the Airflow
Operator, retrieves the SQL query at task runtime from a remote
location, modifies the query as needed, and triggers the upstream
execute method.
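
For illustration, the database gate enabled by `get_profile_type` looks roughly like the check below (it mirrors the validation in `DbtRunAirflowAsyncOperator.__init__`; treat it as a sketch rather than the exact implementation):

```python
from cosmos.config import ProfileConfig
from cosmos.exceptions import CosmosValueError

_SUPPORTED_DATABASES = ["bigquery"]


def validate_async_profile(profile_config: ProfileConfig) -> str:
    """Resolve the warehouse type and reject unsupported setups, as the async run operator does."""
    if not profile_config or not profile_config.profile_mapping:
        raise CosmosValueError("Cosmos async support is only available when using ProfileMapping")
    profile_type = profile_config.get_profile_type()
    if profile_type not in _SUPPORTED_DATABASES:
        raise CosmosValueError(f"Async runs are only supported for: {_SUPPORTED_DATABASES}")
    return profile_type
```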

## Limitations

- This feature only works when using Airflow 2.8 and above
- The async execution only works for BigQuery
- The async execution only supports running dbt models (other dbt
resources, such as seeds, sources, snapshots, and tests, are run using
`ExecutionMode.LOCAL`)
- This works only if the user sets `full_refresh=True` in
`operator_args` (which means tables will be dropped before being
populated, as implemented in `dbt-core`); see the example sketch after
this list
- Users need to use `ProfileMapping` in `ProfileConfig`, since Cosmos
relies on having the connection (credentials) to be able to run the
transformation in BQ without `dbt-core`
- Users must provide the BQ `location` in `operator_args` (this is a
limitation of the `BigQueryInsertJobOperator` that is used to
implement the native Airflow asynchronous support)
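
A minimal sketch of a DAG satisfying these requirements is below. The profile mapping class, connection id, dataset, and project path are illustrative assumptions rather than values mandated by this PR:

```python
from datetime import datetime

from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig
from cosmos.constants import ExecutionMode
from cosmos.profiles import GoogleCloudServiceAccountFileProfileMapping

profile_config = ProfileConfig(
    profile_name="default",
    target_name="dev",
    # A ProfileMapping is required so Cosmos can reuse the Airflow connection
    # credentials to run the transformation in BigQuery without dbt-core.
    profile_mapping=GoogleCloudServiceAccountFileProfileMapping(
        conn_id="gcp_conn",  # assumed Airflow connection id
        profile_args={"dataset": "my_dataset"},  # assumed BigQuery dataset
    ),
)

simple_dag_async = DbtDag(
    project_config=ProjectConfig("/usr/local/airflow/dbt/jaffle_shop"),  # placeholder project path
    profile_config=profile_config,
    execution_config=ExecutionConfig(execution_mode=ExecutionMode.AIRFLOW_ASYNC),
    operator_args={
        "full_refresh": True,  # required: tables are dropped before being populated
        "location": "US",  # required by BigQueryInsertJobOperator
    },
    schedule=None,
    start_date=datetime(2024, 1, 1),
    catchup=False,
    dag_id="simple_dag_async",
)
```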

## Testing 

We have added a new dbt project to the repository to facilitate
asynchronous task execution. The goal is to accelerate development
without disrupting or requiring fixes for the existing tests. We have
also added a DAG for end-to-end testing:
https://github.com/astronomer/astronomer-cosmos/blob/bd6657a29b111510fc34b2baf0bcc0d65ec0e5b9/dev/dags/simple_dag_async.py

## Configuration

Users need to configure the parameters below to execute deferrable
tasks in Cosmos (see the sketch after the links):

- [ExecutionMode:
AIRFLOW_ASYNC](https://astronomer.github.io/astronomer-cosmos/getting_started/execution-modes.html)
- [remote_target_path](https://astronomer.github.io/astronomer-cosmos/configuration/cosmos-conf.html#remote-target-path)
- [remote_target_path_conn_id](https://astronomer.github.io/astronomer-cosmos/configuration/cosmos-conf.html#remote-target-path-conn-id)
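
As a sketch, assuming Airflow's standard `AIRFLOW__<SECTION>__<KEY>` environment-variable mapping for the `[cosmos]` section, the remote target path options could be set as follows (the bucket and connection id are placeholders; in a real deployment these would normally be set in `airflow.cfg` or the runtime environment rather than in Python):

```python
import os

# Assumed variable names following Airflow's config-to-env-var convention.
os.environ["AIRFLOW__COSMOS__REMOTE_TARGET_PATH"] = "gs://my-bucket/target"  # placeholder bucket
os.environ["AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID"] = "gcp_conn"  # placeholder connection id
```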

Example DAG:
https://github.com/astronomer/astronomer-cosmos/blob/bd6657a29b111510fc34b2baf0bcc0d65ec0e5b9/dev/dags/simple_dag_async.py

## Installation

You can leverage async operator support by installing the additional
dependencies:
```
pip install "astronomer-cosmos[dbt-bigquery, google]"
```


## Documentation 

The PR also documents the limitations and usage of Airflow async
execution in Cosmos.

## Related Issue(s)

Related to: #1120
Closes: #1134

## Breaking Change?

No

## Notes

This is an experimental feature, and as such, it may undergo breaking
changes. We encourage users to share their experiences and feedback to
improve it further.

We'd love support and feedback so we can define the next steps.

## Checklist

- [x] I have made corresponding changes to the documentation (if
required)
- [x] I have added tests that prove my fix is effective or that my
feature works

## Credits

This was a result of teamwork and effort:

Co-authored-by: Pankaj Koti <[email protected]>
Co-authored-by: Tatiana Al-Chueyr <[email protected]>

## Future Work

- Design an interface to facilitate the easy addition of new asynchronous
database operators
#1238
- Improve the test coverage
#1239
- Address the limitations (we need to log these issues)

3 people authored Oct 3, 2024
1 parent 93eb17e commit 111d430
Showing 37 changed files with 1,226 additions and 118 deletions.
3 changes: 0 additions & 3 deletions .gitignore
@@ -160,6 +160,3 @@ webserver_config.py

# VI
*.sw[a-z]

# Ignore possibly created symlink to `dev/dags` for running `airflow dags test` command.
dags
3 changes: 2 additions & 1 deletion cosmos/__init__.py
@@ -5,7 +5,8 @@
Contains dags, task groups, and operators.
"""
__version__ = "1.6.0"

__version__ = "1.7.0a1"


from cosmos.airflow.dag import DbtDag
33 changes: 26 additions & 7 deletions cosmos/airflow/graph.py
@@ -132,6 +132,7 @@ def create_task_metadata(
node: DbtNode,
execution_mode: ExecutionMode,
args: dict[str, Any],
dbt_dag_task_group_identifier: str,
use_task_group: bool = False,
source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE,
) -> TaskMetadata | None:
@@ -142,6 +143,7 @@
:param execution_mode: Where Cosmos should run each dbt task (e.g. ExecutionMode.LOCAL, ExecutionMode.KUBERNETES).
Default is ExecutionMode.LOCAL.
:param args: Arguments to be used to instantiate an Airflow Task
:param dbt_dag_task_group_identifier: Identifier to refer to the DbtDAG or DbtTaskGroup in the DAG.
:param use_task_group: It determines whether to use the name as a prefix for the task id or not.
If it is False, then use the name as a prefix for the task id, otherwise do not.
:returns: The metadata necessary to instantiate the source dbt node as an Airflow task.
@@ -156,7 +158,10 @@
args = {**args, **{"models": node.resource_name}}

if DbtResourceType(node.resource_type) in DEFAULT_DBT_RESOURCES and node.resource_type in dbt_resource_to_class:
extra_context = {"dbt_node_config": node.context_dict}
extra_context = {
"dbt_node_config": node.context_dict,
"dbt_dag_task_group_identifier": dbt_dag_task_group_identifier,
}
if node.resource_type == DbtResourceType.MODEL:
task_id = f"{node.name}_run"
if use_task_group is True:
@@ -226,6 +231,7 @@ def generate_task_or_group(
node=node,
execution_mode=execution_mode,
args=task_args,
dbt_dag_task_group_identifier=_get_dbt_dag_task_group_identifier(dag, task_group),
use_task_group=use_task_group,
source_rendering_behavior=source_rendering_behavior,
)
@@ -268,14 +274,28 @@ def _add_dbt_compile_task(
id=DBT_COMPILE_TASK_ID,
operator_class="cosmos.operators.airflow_async.DbtCompileAirflowAsyncOperator",
arguments=task_args,
extra_context={},
extra_context={"dbt_dag_task_group_identifier": _get_dbt_dag_task_group_identifier(dag, task_group)},
)
compile_airflow_task = create_airflow_task(compile_task_metadata, dag, task_group=task_group)

for task_id, task in tasks_map.items():
if not task.upstream_list:
compile_airflow_task >> task

tasks_map[DBT_COMPILE_TASK_ID] = compile_airflow_task

for node_id, node in nodes.items():
if not node.depends_on and node_id in tasks_map:
tasks_map[DBT_COMPILE_TASK_ID] >> tasks_map[node_id]

def _get_dbt_dag_task_group_identifier(dag: DAG, task_group: TaskGroup | None) -> str:
dag_id = dag.dag_id
task_group_id = task_group.group_id if task_group else None
identifiers_list = []
if dag_id:
identifiers_list.append(dag_id)
if task_group_id:
identifiers_list.append(task_group_id)
dag_task_group_identifier = "__".join(identifiers_list)

return dag_task_group_identifier


def build_airflow_graph(
@@ -358,9 +378,8 @@ def build_airflow_graph(
for leaf_node_id in leaves_ids:
tasks_map[leaf_node_id] >> test_task

_add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map, task_group)

create_airflow_task_dependencies(nodes, tasks_map)
_add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map, task_group)


def create_airflow_task_dependencies(
16 changes: 16 additions & 0 deletions cosmos/config.py
@@ -10,6 +10,7 @@
from pathlib import Path
from typing import Any, Callable, Iterator

import yaml
from airflow.version import version as airflow_version

from cosmos.cache import create_cache_profile, get_cached_profile, is_profile_cache_enabled
@@ -286,6 +287,21 @@ def validate_profiles_yml(self) -> None:
if self.profiles_yml_filepath and not Path(self.profiles_yml_filepath).exists():
raise CosmosValueError(f"The file {self.profiles_yml_filepath} does not exist.")

def get_profile_type(self) -> str:
if isinstance(self.profile_mapping, BaseProfileMapping):
return str(self.profile_mapping.dbt_profile_type)

profile_path = self._get_profile_path()

with open(profile_path) as file:
profiles = yaml.safe_load(file)

profile = profiles[self.profile_name]
target_type = profile["outputs"][self.target_name]["type"]
return str(target_type)

return "undefined"

def _get_profile_path(self, use_mock_values: bool = False) -> Path:
"""
Handle the profile caching mechanism.
3 changes: 2 additions & 1 deletion cosmos/core/airflow.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import importlib
from typing import Any

from airflow.models import BaseOperator
from airflow.models.dag import DAG
@@ -27,7 +28,7 @@ def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None)
module = importlib.import_module(module_name)
Operator = getattr(module, class_name)

task_kwargs = {}
task_kwargs: dict[str, Any] = {}
if task.owner != "":
task_kwargs["owner"] = task.owner

1 change: 0 additions & 1 deletion cosmos/dbt/graph.py
@@ -467,7 +467,6 @@ def should_use_dbt_ls_cache(self) -> bool:

def load_via_dbt_ls_cache(self) -> bool:
"""(Try to) load dbt ls cache from an Airflow Variable"""

logger.info(f"Trying to parse the dbt project using dbt ls cache {self.dbt_ls_cache_key}...")
if self.should_use_dbt_ls_cache():
project_path = self.project_path
179 changes: 151 additions & 28 deletions cosmos/operators/airflow_async.py
@@ -1,67 +1,190 @@
from __future__ import annotations

import inspect
from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
from airflow.utils.context import Context

from cosmos import settings
from cosmos.config import ProfileConfig
from cosmos.exceptions import CosmosValueError
from cosmos.operators.base import AbstractDbtBaseOperator
from cosmos.operators.local import (
DbtBuildLocalOperator,
DbtCompileLocalOperator,
DbtDocsAzureStorageLocalOperator,
DbtDocsGCSLocalOperator,
DbtDocsLocalOperator,
DbtDocsS3LocalOperator,
DbtLocalBaseOperator,
DbtLSLocalOperator,
DbtRunLocalOperator,
DbtRunOperationLocalOperator,
DbtSeedLocalOperator,
DbtSnapshotLocalOperator,
DbtSourceLocalOperator,
DbtTestLocalOperator,
)
from cosmos.settings import remote_target_path, remote_target_path_conn_id

_SUPPORTED_DATABASES = ["bigquery"]

class DbtBuildAirflowAsyncOperator(DbtBuildLocalOperator):
pass
from abc import ABCMeta


class DbtLSAirflowAsyncOperator(DbtLSLocalOperator):
pass
from airflow.models.baseoperator import BaseOperator


class DbtSeedAirflowAsyncOperator(DbtSeedLocalOperator):
pass


class DbtSnapshotAirflowAsyncOperator(DbtSnapshotLocalOperator):
pass


class DbtSourceAirflowAsyncOperator(DbtSourceLocalOperator):
pass
class DbtBaseAirflowAsyncOperator(BaseOperator, metaclass=ABCMeta):
def __init__(self, **kwargs) -> None: # type: ignore
self.location = kwargs.pop("location")
self.configuration = kwargs.pop("configuration", {})
super().__init__(**kwargs)


class DbtRunAirflowAsyncOperator(DbtRunLocalOperator):
class DbtBuildAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtBuildLocalOperator): # type: ignore
pass


class DbtTestAirflowAsyncOperator(DbtTestLocalOperator):
class DbtLSAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtLSLocalOperator): # type: ignore
pass


class DbtRunOperationAirflowAsyncOperator(DbtRunOperationLocalOperator):
class DbtSeedAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSeedLocalOperator): # type: ignore
pass


class DbtDocsAirflowAsyncOperator(DbtDocsLocalOperator):
class DbtSnapshotAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSnapshotLocalOperator): # type: ignore
pass


class DbtDocsS3AirflowAsyncOperator(DbtDocsS3LocalOperator):
class DbtSourceAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSourceLocalOperator): # type: ignore
pass


class DbtDocsAzureStorageAirflowAsyncOperator(DbtDocsAzureStorageLocalOperator):
class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator): # type: ignore

template_fields: Sequence[str] = (
"full_refresh",
"project_dir",
"gcp_project",
"dataset",
"location",
)

def __init__( # type: ignore
self,
project_dir: str,
profile_config: ProfileConfig,
location: str, # This is a mandatory parameter when using BigQueryInsertJobOperator with deferrable=True
full_refresh: bool = False,
extra_context: dict[str, object] | None = None,
configuration: dict[str, object] | None = None,
**kwargs,
) -> None:
# dbt task param
self.project_dir = project_dir
self.extra_context = extra_context or {}
self.full_refresh = full_refresh
self.profile_config = profile_config
if not self.profile_config or not self.profile_config.profile_mapping:
raise CosmosValueError(f"Cosmos async support is only available when using ProfileMapping")

self.profile_type: str = profile_config.get_profile_type() # type: ignore
if self.profile_type not in _SUPPORTED_DATABASES:
raise CosmosValueError(f"Async run are only supported: {_SUPPORTED_DATABASES}")

# airflow task param
self.location = location
self.configuration = configuration or {}
self.gcp_conn_id = self.profile_config.profile_mapping.conn_id # type: ignore
profile = self.profile_config.profile_mapping.profile
self.gcp_project = profile["project"]
self.dataset = profile["dataset"]

# Cosmos attempts to pass many kwargs that BigQueryInsertJobOperator simply does not accept.
# We need to pop them.
clean_kwargs = {}
non_async_args = set(inspect.signature(AbstractDbtBaseOperator.__init__).parameters.keys())
non_async_args |= set(inspect.signature(DbtLocalBaseOperator.__init__).parameters.keys())
non_async_args -= {"task_id"}

for arg_key, arg_value in kwargs.items():
if arg_key not in non_async_args:
clean_kwargs[arg_key] = arg_value

# The following are the minimum required parameters to run BigQueryInsertJobOperator using the deferrable mode
super().__init__(
gcp_conn_id=self.gcp_conn_id,
configuration=self.configuration,
location=self.location,
deferrable=True,
**clean_kwargs,
)

def get_remote_sql(self) -> str:
if not settings.AIRFLOW_IO_AVAILABLE:
raise CosmosValueError(f"Cosmos async support is only available starting in Airflow 2.8 or later.")
from airflow.io.path import ObjectStoragePath

file_path = self.extra_context["dbt_node_config"]["file_path"] # type: ignore
dbt_dag_task_group_identifier = self.extra_context["dbt_dag_task_group_identifier"]

remote_target_path_str = str(remote_target_path).rstrip("/")

if TYPE_CHECKING:
assert self.project_dir is not None

project_dir_parent = str(Path(self.project_dir).parent)
relative_file_path = str(file_path).replace(project_dir_parent, "").lstrip("/")
remote_model_path = f"{remote_target_path_str}/{dbt_dag_task_group_identifier}/compiled/{relative_file_path}"

object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
with object_storage_path.open() as fp: # type: ignore
return fp.read() # type: ignore

def drop_table_sql(self) -> None:
model_name = self.extra_context["dbt_node_config"]["resource_name"] # type: ignore
sql = f"DROP TABLE IF EXISTS {self.gcp_project}.{self.dataset}.{model_name};"

hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
self.configuration = {
"query": {
"query": sql,
"useLegacySql": False,
}
}
hook.insert_job(configuration=self.configuration, location=self.location, project_id=self.gcp_project)

def execute(self, context: Context) -> Any | None:
if not self.full_refresh:
raise CosmosValueError("The async execution only supported for full_refresh")
else:
# It may be surprising to some, but the dbt-core --full-refresh argument fully drops the table before populating it
# https://github.com/dbt-labs/dbt-core/blob/5e9f1b515f37dfe6cdae1ab1aa7d190b92490e24/core/dbt/context/base.py#L662-L666
# https://docs.getdbt.com/reference/resource-configs/full_refresh#recommendation
# We're emulating this behaviour here
self.drop_table_sql()
sql = self.get_remote_sql()
model_name = self.extra_context["dbt_node_config"]["resource_name"] # type: ignore
# prefix explicit create command to create table
sql = f"CREATE TABLE {self.gcp_project}.{self.dataset}.{model_name} AS {sql}"
self.configuration = {
"query": {
"query": sql,
"useLegacySql": False,
}
}
return super().execute(context)


class DbtTestAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtTestLocalOperator): # type: ignore
pass


class DbtDocsGCSAirflowAsyncOperator(DbtDocsGCSLocalOperator):
class DbtRunOperationAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtRunOperationLocalOperator): # type: ignore
pass


class DbtCompileAirflowAsyncOperator(DbtCompileLocalOperator):
class DbtCompileAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtCompileLocalOperator): # type: ignore
pass
