def check_branch_object(task_instance, repo: str, branch: str, path: str):
    """Verify that ``path`` can no longer be read from ``branch`` in ``repo``.

    Intended as a ``PythonOperator`` callable that runs after the branch has
    been deleted: a successful read means the object is still reachable, which
    is treated as a task failure.

    :param task_instance: accepted for the callable's signature; not used here.
    :param repo: lakeFS repository name.
    :param branch: branch (ref) the object is looked up on.
    :param path: object path inside the branch.
    :raises AirflowFailException: if the object lookup succeeds.
    """
    hook = LakeFSHook(default_args['lakefs_conn_id'])
    print(f"Trying to check if the following path exists: lakefs://{repo}/{branch}/{path}")
    try:
        # Keep the guarded region to the single call expected to raise.
        hook.get_object(repo=repo, ref=branch, path=path)
    except NotFoundException as e:
        print(f"Path not found, as expected: {e}")
    else:
        # The lookup succeeded, so the object is still reachable: fail the task.
        raise AirflowFailException("Path found, this is not to be expected.")
class LakeFSDeleteBranchOperator(BaseOperator):
    """
    Delete a lakeFS branch by calling the lakeFS server.

    The operator builds a ``LakeFSHook`` from ``lakefs_conn_id`` and delegates
    to the hook's ``delete_branch`` method.

    :param lakefs_conn_id: connection to run the operator with
    :type lakefs_conn_id: str
    :param repo: The lakeFS repo where the branch is deleted.
    :type repo: str
    :param branch: The branch name to delete
    :type branch: str
    """

    # Specify the arguments that are allowed to parse with jinja templating.
    # A tuple is used so this class-level attribute cannot be mutated by
    # accident and leak changes across operator instances.
    template_fields = (
        'repo',
        'branch',
    )
    template_ext = ()
    ui_color = '#f4a460'

    @apply_defaults
    def __init__(self, lakefs_conn_id: str, repo: str, branch: str, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.lakefs_conn_id = lakefs_conn_id
        self.repo = repo
        self.branch = branch

    def execute(self, context: Dict[str, Any]) -> Any:
        """Delete ``self.branch`` in ``self.repo`` through the lakeFS hook.

        :param context: Airflow task context; not used by this operator.
        :return: whatever ``LakeFSHook.delete_branch`` returns.
        """
        hook = LakeFSHook(lakefs_conn_id=self.lakefs_conn_id)

        # Lazy %-style arguments: the message is only rendered if the record
        # is actually emitted.
        self.log.info("Delete lakeFS branch %s in repo %s", self.branch, self.repo)
        return hook.delete_branch(repo=self.repo, branch=self.branch)