Add TaskGroups to confirm that LakeFSCommitOperator, LakeFSGetCommitOperator, LakeFSMergeOperator, and LakeFSUploadOperator work inside TaskGroups

Signed-off-by: Fredrik Bakken <[email protected]>
FredrikBakken committed Aug 13, 2023
1 parent 8b44de3 commit 97ea5c3
Showing 1 changed file with 49 additions and 40 deletions.
lakefs_provider/example_dags/lakefs-dag.py (49 additions, 40 deletions)
@@ -10,6 +10,7 @@
 from airflow.decorators import dag
 from airflow.utils.dates import days_ago
 from airflow.exceptions import AirflowFailException
+from airflow.utils.task_group import TaskGroup
 
 from lakefs_provider.hooks.lakefs_hook import LakeFSHook
 from lakefs_provider.operators.create_branch_operator import LakeFSCreateBranchOperator
@@ -87,21 +88,19 @@ def lakeFS_workflow():
     - extra: {"access_key_id":"AKIAIOSFODNN7EXAMPLE","secret_access_key":"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"}
     """
 
-    # Create the branch to run on
-    task_create_branch = LakeFSCreateBranchOperator(
-        task_id='create_branch',
-        source_branch=default_args.get('default-branch')
-    )
+    with TaskGroup(group_id="create_branch") as create_branch:
+        # Create the branch to run on
+        task_create_branch = LakeFSCreateBranchOperator(
+            task_id='create_branch',
+            source_branch=default_args.get('default-branch')
+        )
 
-    # Create a path.
-    task_create_file = LakeFSUploadOperator(
-        task_id='upload_file',
-        content=NamedStringIO(content=f"{CONTENT_PREFIX} @{time.asctime()}", name='content'))
+        task_get_branch_commit = LakeFSGetCommitOperator(
+            do_xcom_push=True,
+            task_id='get_branch_commit',
+            ref=default_args['branch'])
 
-    task_get_branch_commit = LakeFSGetCommitOperator(
-        do_xcom_push=True,
-        task_id='get_branch_commit',
-        ref=default_args['branch'])
+        task_create_branch >> task_get_branch_commit
 
     # Checks periodically for the path.
     # DAG continues only when the file exists.
@@ -112,23 +111,31 @@ def lakeFS_workflow():
         timeout=10,
     )
 
-    # Commit the changes to the branch.
-    # (Also a good place to validate the new changes before committing them)
-    task_commit = LakeFSCommitOperator(
-        task_id='commit',
-        msg=COMMIT_MESSAGE_1,
-        metadata={"committed_from": "airflow-operator"}
-    )
+    with TaskGroup(group_id="upload_and_commit") as upload_and_commit:
+        # Create a path.
+        task_create_file = LakeFSUploadOperator(
+            task_id='upload_file',
+            content=NamedStringIO(content=f"{CONTENT_PREFIX} @{time.asctime()}", name='content'))
+
+        # Commit the changes to the branch.
+        # (Also a good place to validate the new changes before committing them)
+        task_commit = LakeFSCommitOperator(
+            task_id='commit',
+            msg=COMMIT_MESSAGE_1,
+            metadata={"committed_from": "airflow-operator"}
+        )
 
-    # Create symlink file for example-branch
-    task_create_symlink = LakeFSCreateSymlinkOperator(task_id="create_symlink")
+        # Create symlink file for example-branch
+        task_create_symlink = LakeFSCreateSymlinkOperator(task_id="create_symlink")
+
+        task_create_file >> task_commit >> task_create_symlink
 
     # Wait until the commit is completed.
     # Not really necessary in this DAG, since the LakeFSCommitOperator won't return before that.
     # Nonetheless we added it to show the full capabilities.
     task_sense_commit = LakeFSCommitSensor(
         task_id='sense_commit',
-        prev_commit_id='''{{ task_instance.xcom_pull(task_ids='get_branch_commit', key='return_value').id }}''',
+        prev_commit_id='''{{ task_instance.xcom_pull(task_ids='create_branch.get_branch_commit', key='return_value').id }}''',
         mode='reschedule',
         poke_interval=1,
         timeout=10,
@@ -149,18 +156,21 @@ def lakeFS_workflow():
             'expected': CONTENT_PREFIX,
         })
 
-    # Merge the changes back to the main branch.
-    task_merge = LakeFSMergeOperator(
-        task_id='merge_branches',
-        do_xcom_push=True,
-        source_ref=default_args.get('branch'),
-        destination_branch=default_args.get('default-branch'),
-        msg=MERGE_MESSAGE_1,
-        metadata={"committer": "airflow-operator"}
-    )
-
-    expectedCommits = ['''{{ ti.xcom_pull('merge_branches') }}''',
-                       '''{{ ti.xcom_pull('commit') }}''']
+    with TaskGroup(group_id="merge_branch") as merge_branch:
+        # Merge the changes back to the main branch.
+        task_merge = LakeFSMergeOperator(
+            task_id='merge_branches',
+            do_xcom_push=True,
+            source_ref=default_args.get('branch'),
+            destination_branch=default_args.get('default-branch'),
+            msg=MERGE_MESSAGE_1,
+            metadata={"committer": "airflow-operator"}
+        )
+
+        task_merge
+
+    expectedCommits = ['''{{ ti.xcom_pull('merge_branch.merge_branches') }}''',
+                       '''{{ ti.xcom_pull('upload_and_commit.commit') }}''']
     expectedMessages = [MERGE_MESSAGE_1, COMMIT_MESSAGE_1]
 
     # Fetch and verify log messages in bulk.
@@ -169,7 +179,7 @@ def lakeFS_workflow():
         python_callable=check_logs,
         op_kwargs={
             'repo': default_args.get('repo'),
-            'ref': '''{{ task_instance.xcom_pull(task_ids='merge_branches', key='return_value') }}''',
+            'ref': '''{{ task_instance.xcom_pull(task_ids='merge_branch.merge_branches', key='return_value') }}''',
             'commits': expectedCommits,
             'messages': expectedMessages,
         })
@@ -180,17 +190,16 @@ def lakeFS_workflow():
         python_callable=check_logs,
         op_kwargs= {
             'repo': default_args.get('repo'),
-            'ref': '''{{ task_instance.xcom_pull(task_ids='merge_branches', key='return_value') }}''',
+            'ref': '''{{ task_instance.xcom_pull(task_ids='merge_branch.merge_branches', key='return_value') }}''',
             'amount': 1,
             'commits': expectedCommits,
             'messages': expectedMessages,
         })
 
 
-    task_create_branch >> task_get_branch_commit >> [task_create_file, task_sense_commit, task_sense_file]
-    task_create_file >> task_commit >> task_create_symlink
+    create_branch >> [upload_and_commit, task_sense_commit, task_sense_file]
     task_sense_file >> task_get_file >> task_check_contents
-    task_sense_commit >> task_merge >> [task_check_logs_bulk, task_check_logs_individually]
+    task_sense_commit >> merge_branch >> [task_check_logs_bulk, task_check_logs_individually]
 
 
 sample_workflow_dag = lakeFS_workflow()
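
The dotted ids in the updated Jinja pulls ('create_branch.get_branch_commit', 'upload_and_commit.commit', 'merge_branch.merge_branches') follow from how Airflow names tasks inside a TaskGroup: the group_id is prepended to each child's task_id. Below is a minimal standalone sketch of that convention, assuming Airflow 2.x; the DAG, group, and task names are hypothetical and not part of this commit.

import datetime

from airflow.decorators import dag
from airflow.operators.python import PythonOperator
from airflow.utils.task_group import TaskGroup


# Hypothetical example, not part of the lakeFS provider.
@dag(schedule_interval=None, start_date=datetime.datetime(2023, 8, 1), catchup=False)
def taskgroup_xcom_sketch():
    with TaskGroup(group_id="produce") as produce:
        # Inside the group this task's effective id becomes 'produce.push_value'.
        push_value = PythonOperator(
            task_id='push_value',
            python_callable=lambda: 'hello',  # the return value is pushed to XCom
        )

    # Outside the group the XCom must be pulled with the group-prefixed id,
    # just as the DAG above pulls 'merge_branch.merge_branches'.
    pull_value = PythonOperator(
        task_id='pull_value',
        python_callable=lambda ti: print(ti.xcom_pull(task_ids='produce.push_value')),
    )

    # Dependencies can be declared on the TaskGroup object itself; Airflow
    # expands this to the group's root and leaf tasks.
    produce >> pull_value


taskgroup_xcom_sketch_dag = taskgroup_xcom_sketch()

The same prefixing is why the final dependency lines in the commit can wire create_branch, upload_and_commit, and merge_branch directly: applying >> to a group attaches the edge to the group's boundary tasks rather than to a single operator.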
