diff --git a/lakefs_provider/example_dags/lakefs-dag.py b/lakefs_provider/example_dags/lakefs-dag.py
index 60c5f8f..09e4467 100644
--- a/lakefs_provider/example_dags/lakefs-dag.py
+++ b/lakefs_provider/example_dags/lakefs-dag.py
@@ -10,6 +10,7 @@
 from airflow.decorators import dag
 from airflow.utils.dates import days_ago
 from airflow.exceptions import AirflowFailException
+from airflow.utils.task_group import TaskGroup
 
 from lakefs_provider.hooks.lakefs_hook import LakeFSHook
 from lakefs_provider.operators.create_branch_operator import LakeFSCreateBranchOperator
@@ -87,21 +88,19 @@ def lakeFS_workflow():
     - extra: {"access_key_id":"AKIAIOSFODNN7EXAMPLE","secret_access_key":"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"}
     """
 
-    # Create the branch to run on
-    task_create_branch = LakeFSCreateBranchOperator(
-        task_id='create_branch',
-        source_branch=default_args.get('default-branch')
-    )
+    with TaskGroup(group_id="create_branch") as create_branch:
+        # Create the branch to run on
+        task_create_branch = LakeFSCreateBranchOperator(
+            task_id='create_branch',
+            source_branch=default_args.get('default-branch')
+        )
 
-    # Create a path.
-    task_create_file = LakeFSUploadOperator(
-        task_id='upload_file',
-        content=NamedStringIO(content=f"{CONTENT_PREFIX} @{time.asctime()}", name='content'))
+        task_get_branch_commit = LakeFSGetCommitOperator(
+            do_xcom_push=True,
+            task_id='get_branch_commit',
+            ref=default_args['branch'])
 
-    task_get_branch_commit = LakeFSGetCommitOperator(
-        do_xcom_push=True,
-        task_id='get_branch_commit',
-        ref=default_args['branch'])
+        task_create_branch >> task_get_branch_commit
 
     # Checks periodically for the path.
     # DAG continues only when the file exists.
@@ -112,23 +111,31 @@ def lakeFS_workflow():
         timeout=10,
     )
 
-    # Commit the changes to the branch.
-    # (Also a good place to validate the new changes before committing them)
-    task_commit = LakeFSCommitOperator(
-        task_id='commit',
-        msg=COMMIT_MESSAGE_1,
-        metadata={"committed_from": "airflow-operator"}
-    )
+    with TaskGroup(group_id="upload_and_commit") as upload_and_commit:
+        # Create a path.
+        task_create_file = LakeFSUploadOperator(
+            task_id='upload_file',
+            content=NamedStringIO(content=f"{CONTENT_PREFIX} @{time.asctime()}", name='content'))
+
+        # Commit the changes to the branch.
+        # (Also a good place to validate the new changes before committing them)
+        task_commit = LakeFSCommitOperator(
+            task_id='commit',
+            msg=COMMIT_MESSAGE_1,
+            metadata={"committed_from": "airflow-operator"}
+        )
 
-    # Create symlink file for example-branch
-    task_create_symlink = LakeFSCreateSymlinkOperator(task_id="create_symlink")
+        # Create symlink file for example-branch
+        task_create_symlink = LakeFSCreateSymlinkOperator(task_id="create_symlink")
+
+        task_create_file >> task_commit >> task_create_symlink
 
     # Wait until the commit is completed.
     # Not really necessary in this DAG, since the LakeFSCommitOperator won't return before that.
     # Nonetheless we added it to show the full capabilities.
     task_sense_commit = LakeFSCommitSensor(
         task_id='sense_commit',
-        prev_commit_id='''{{ task_instance.xcom_pull(task_ids='get_branch_commit', key='return_value').id }}''',
+        prev_commit_id='''{{ task_instance.xcom_pull(task_ids='create_branch.get_branch_commit', key='return_value').id }}''',
         mode='reschedule',
         poke_interval=1,
         timeout=10,
@@ -149,18 +156,21 @@ def lakeFS_workflow():
             'expected': CONTENT_PREFIX,
         })
 
-    # Merge the changes back to the main branch.
-    task_merge = LakeFSMergeOperator(
-        task_id='merge_branches',
-        do_xcom_push=True,
-        source_ref=default_args.get('branch'),
-        destination_branch=default_args.get('default-branch'),
-        msg=MERGE_MESSAGE_1,
-        metadata={"committer": "airflow-operator"}
-    )
-
-    expectedCommits = ['''{{ ti.xcom_pull('merge_branches') }}''',
-                       '''{{ ti.xcom_pull('commit') }}''']
+    with TaskGroup(group_id="merge_branch") as merge_branch:
+        # Merge the changes back to the main branch.
+        task_merge = LakeFSMergeOperator(
+            task_id='merge_branches',
+            do_xcom_push=True,
+            source_ref=default_args.get('branch'),
+            destination_branch=default_args.get('default-branch'),
+            msg=MERGE_MESSAGE_1,
+            metadata={"committer": "airflow-operator"}
+        )
+
+        task_merge
+
+    expectedCommits = ['''{{ ti.xcom_pull('merge_branch.merge_branches') }}''',
+                       '''{{ ti.xcom_pull('upload_and_commit.commit') }}''']
     expectedMessages = [MERGE_MESSAGE_1, COMMIT_MESSAGE_1]
 
     # Fetch and verify log messages in bulk.
@@ -169,7 +179,7 @@ def lakeFS_workflow():
         python_callable=check_logs,
         op_kwargs={
             'repo': default_args.get('repo'),
-            'ref': '''{{ task_instance.xcom_pull(task_ids='merge_branches', key='return_value') }}''',
+            'ref': '''{{ task_instance.xcom_pull(task_ids='merge_branch.merge_branches', key='return_value') }}''',
             'commits': expectedCommits,
             'messages': expectedMessages,
         })
@@ -180,17 +190,16 @@ def lakeFS_workflow():
         python_callable=check_logs,
         op_kwargs= {
             'repo': default_args.get('repo'),
-            'ref': '''{{ task_instance.xcom_pull(task_ids='merge_branches', key='return_value') }}''',
+            'ref': '''{{ task_instance.xcom_pull(task_ids='merge_branch.merge_branches', key='return_value') }}''',
             'amount': 1,
             'commits': expectedCommits,
             'messages': expectedMessages,
         })
 
-    task_create_branch >> task_get_branch_commit >> [task_create_file, task_sense_commit, task_sense_file]
-    task_create_file >> task_commit >> task_create_symlink
+    create_branch >> [upload_and_commit, task_sense_commit, task_sense_file]
     task_sense_file >> task_get_file >> task_check_contents
-    task_sense_commit >> task_merge >> [task_check_logs_bulk, task_check_logs_individually]
+    task_sense_commit >> merge_branch >> [task_check_logs_bulk, task_check_logs_individually]
 
 
 sample_workflow_dag = lakeFS_workflow()