Skip to content

Commit

Permalink
Add branch delete operator and provider test (#70)
Browse files Browse the repository at this point in the history
Signed-off-by: Fredrik Bakken <[email protected]>
  • Loading branch information
FredrikBakken authored Aug 14, 2023
1 parent ea452e4 commit 4f3e790
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 1 deletion.
30 changes: 29 additions & 1 deletion lakefs_provider/example_dags/lakefs-dag.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import Dict
from typing import Sequence

from collections import namedtuple
Expand All @@ -11,9 +10,11 @@
from airflow.utils.dates import days_ago
from airflow.exceptions import AirflowFailException

from lakefs_client.exceptions import NotFoundException
from lakefs_provider.hooks.lakefs_hook import LakeFSHook
from lakefs_provider.operators.create_branch_operator import LakeFSCreateBranchOperator
from lakefs_provider.operators.create_symlink_operator import LakeFSCreateSymlinkOperator
from lakefs_provider.operators.delete_branch_operator import LakeFSDeleteBranchOperator
from lakefs_provider.operators.merge_operator import LakeFSMergeOperator
from lakefs_provider.operators.upload_operator import LakeFSUploadOperator
from lakefs_provider.operators.commit_operator import LakeFSCommitOperator
Expand Down Expand Up @@ -62,6 +63,16 @@ def check_logs(task_instance, repo: str, ref: str, commits: Sequence[str], messa
raise AirflowFailException(f'Got {actual} instead of {expected}')


def check_branch_object(task_instance, repo: str, branch: str, path: str):
hook = LakeFSHook(default_args['lakefs_conn_id'])
print(f"Trying to check if the following path exists: lakefs://{repo}/{branch}/{path}")
try:
hook.get_object(repo=repo, ref=branch, path=path)
raise AirflowFailException(f"Path found, this is not to be expected.")
except NotFoundException as e:
print(f"Path not found, as expected: {e}")


class NamedStringIO(StringIO):
def __init__(self, content: str, name: str) -> None:
super().__init__(content)
Expand Down Expand Up @@ -186,11 +197,28 @@ def lakeFS_workflow():
'messages': expectedMessages,
})

task_delete_branch = LakeFSDeleteBranchOperator(
task_id='delete_branch',
repo=default_args.get('repo'),
branch=default_args.get('branch'),
)

task_check_branch_object = PythonOperator(
task_id='check_branch_object',
python_callable=check_branch_object,
op_kwargs={
'repo': default_args.get('repo'),
'branch': default_args.get('branch'),
'path': default_args.get('path'),
}
)

task_create_branch >> task_get_branch_commit >> [task_create_file, task_sense_commit, task_sense_file]
task_create_file >> task_commit >> task_create_symlink
task_sense_file >> task_get_file >> task_check_contents
task_sense_commit >> task_merge >> [task_check_logs_bulk, task_check_logs_individually]
[task_check_contents, task_check_logs_bulk, task_check_logs_individually] >> task_delete_branch
task_delete_branch >> task_check_branch_object


sample_workflow_dag = lakeFS_workflow()
4 changes: 4 additions & 0 deletions lakefs_provider/hooks/lakefs_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ def create_symlink_file(self, repo: str, branch: str, location: str = None) ->

return client.metadata.create_symlink_file(repository=repo, branch=branch, **kwargs)["location"]

def delete_branch(self, repo: str, branch: str) -> str:
client = self.get_conn()
return client.branches.delete_branch(repository=repo, branch=branch)

def test_connection(self):
"""Test Connection"""
conn = self.get_connection(self.lakefs_conn_id)
Expand Down
41 changes: 41 additions & 0 deletions lakefs_provider/operators/delete_branch_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Any, Callable, Dict, Optional

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults

from lakefs_provider.hooks.lakefs_hook import LakeFSHook


class LakeFSDeleteBranchOperator(BaseOperator):
"""
Delete a lakeFS branch by calling the lakeFS server.
:param lakefs_conn_id: connection to run the operator with
:type lakefs_conn_id: str
:param repo: The lakeFS repo where the branch is deleted.
:type repo: str
:param branch: The branch name to delete
:type branch: str
"""

# Specify the arguments that are allowed to parse with jinja templating
template_fields = [
'repo',
'branch',
]
template_ext = ()
ui_color = '#f4a460'

@apply_defaults
def __init__(self, lakefs_conn_id: str, repo: str, branch: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.lakefs_conn_id = lakefs_conn_id
self.repo = repo
self.branch = branch

def execute(self, context: Dict[str, Any]) -> Any:
hook = LakeFSHook(lakefs_conn_id=self.lakefs_conn_id)

self.log.info(f"Delete lakeFS branch {self.branch} in repo {self.repo}")
return hook.delete_branch(self.repo, self.branch)

0 comments on commit 4f3e790

Please sign in to comment.