remove resolver and properly load plugins for airflow and py4j (cronhelper) integration
stikkireddy committed Aug 7, 2023
1 parent 62dee19 commit 50ad7b5
Showing 7 changed files with 75 additions and 98 deletions.
2 changes: 0 additions & 2 deletions brickflow/__init__.py
@@ -182,7 +182,6 @@ def get_bundles_project_env() -> str:
from brickflow.engine.compute import Cluster, Runtimes
from brickflow.engine.project import Project
from brickflow.resolver import (
RelativePathPackageResolver,
get_relative_path_to_brickflow_root,
)

@@ -229,7 +228,6 @@ def get_bundles_project_env() -> str:
"BrickflowDefaultEnvs",
"get_default_log_handler",
"get_brickflow_version",
"RelativePathPackageResolver",
"BrickflowProjectConstants",
]

46 changes: 41 additions & 5 deletions brickflow/engine/task.py
@@ -4,6 +4,7 @@
import dataclasses
import functools
import inspect
import logging
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
@@ -370,7 +371,7 @@ def task_execute(task: "Task", workflow: "Workflow") -> TaskResponse:


@functools.lru_cache
def get_brickflow_tasks_hook() -> BrickflowTaskPluginSpec:
def get_plugin_manager() -> pluggy.PluginManager:
pm = pluggy.PluginManager(BRICKFLOW_TASK_PLUGINS)
pm.add_hookspecs(BrickflowTaskPluginSpec)
pm.load_setuptools_entrypoints(BRICKFLOW_TASK_PLUGINS)
@@ -381,7 +382,22 @@ def get_brickflow_tasks_hook() -> BrickflowTaskPluginSpec:
name,
plugin_instance.__class__.__name__,
)
return pm.hook
return pm


@functools.lru_cache
def get_brickflow_tasks_hook() -> BrickflowTaskPluginSpec:
try:
from brickflow_plugins import load_plugins

load_plugins()
except ImportError as e:
_ilog.info(
"If you need airflow support: brickflow extras not installed "
"please pip install brickflow[airflow] and py4j! Error: %s",
str(e.msg),
)
return get_plugin_manager().hook


@dataclass(frozen=True)
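
For context on the hunk above: get_plugin_manager builds a pluggy PluginManager, registers the BrickflowTaskPluginSpec hook specifications, and loads any implementations advertised through setuptools entry points, while the cached get_brickflow_tasks_hook now tries to import brickflow_plugins and register the airflow implementation before returning the hook relay. Below is a minimal, self-contained sketch of the underlying pluggy pattern, using a hypothetical "demo" project name rather than brickflow's BRICKFLOW_TASK_PLUGINS constant.

import pluggy

hookspec = pluggy.HookspecMarker("demo")  # markers tied to the hypothetical "demo" project name
hookimpl = pluggy.HookimplMarker("demo")


class DemoSpec:
    @hookspec
    def handle_results(self, resp):
        """Hook specification: plugins may post-process a task response."""


class DemoImpl:
    @hookimpl
    def handle_results(self, resp):
        return f"handled: {resp}"


pm = pluggy.PluginManager("demo")
pm.add_hookspecs(DemoSpec)
pm.register(DemoImpl())  # get_plugin_manager instead calls pm.load_setuptools_entrypoints(...)
print(pm.hook.handle_results(resp="ok"))  # hook calls are keyword-only; prints ['handled: ok']

Because get_plugin_manager is wrapped in functools.lru_cache, load_plugins (defined in brickflow_plugins below) attaches the airflow implementation to this same manager instance instead of creating a new one.
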
@@ -397,6 +413,7 @@ class Task:
trigger_rule: BrickflowTriggerRule = BrickflowTriggerRule.ALL_SUCCESS
task_settings: Optional[TaskSettings] = None
custom_execute_callback: Optional[Callable] = None
ensure_brickflow_plugins: bool = False

def __post_init__(self) -> None:
self.is_valid_task_signature()
@@ -466,6 +483,19 @@ def get_obj_dict(self, entrypoint: str) -> Dict[str, Any]:
},
}

def _ensure_brickflow_plugins(self) -> None:
if self.ensure_brickflow_plugins is False:
return
try:
import brickflow_plugins # noqa
except ImportError as e:
raise ImportError(
f"Brickflow Plugins not available for task: {self.name}. "
"If you need airflow support: brickflow extras not installed "
"please pip install brickflow[airflow] and py4j! Error: %s",
str(e.msg),
)

# TODO: error if star isn't there
def is_valid_task_signature(self) -> None:
# only supports kwonlyargs with defaults
@@ -571,15 +601,21 @@ def _skip_because_not_selected(self) -> Tuple[bool, Optional[str]]:
return False, None

@with_brickflow_logger
def execute(self) -> Any:
def execute(self, ignore_all_deps: bool = False) -> Any:
# The execution flow is:
# 1. Check whether any tasks were selected and, if so, whether this task is among them
# 2. Check whether the previous task was skipped and evaluate the trigger rule
# 3. Check whether this is a custom python task and execute it
# 4. Execute the task function
_ilog.setLevel(logging.INFO) # enable logging for task execution
ctx._set_current_task(self.name)
self._ensure_brickflow_plugins()  # fail fast if brickflow plugins are expected but not installed
if ignore_all_deps is True:
_ilog.info(
"Ignoring all dependencies for task: %s due to debugging", self.name
)
_select_task_skip, _select_task_skip_reason = self._skip_because_not_selected()
if _select_task_skip is True:
if _select_task_skip is True and ignore_all_deps is False:
# check if this task is skipped due to task selection
_ilog.info(
"Skipping task... %s for reason: %s",
@@ -589,7 +625,7 @@ def execute(self) -> Any:
ctx._reset_current_task()
return
_skip, reason = self.should_skip()
if _skip is True:
if _skip is True and ignore_all_deps is False:
_ilog.info("Skipping task... %s for reason: %s", self.name, reason)
ctx.task_coms.put(self.name, BRANCH_SKIP_EXCEPT, SKIP_EXCEPT_HACK)
ctx._reset_current_task()
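
The new ignore_all_deps flag on execute is aimed at debugging a single task outside a full run: when set, both the task-selection skip and the trigger-rule skip above are bypassed. A hedged sketch of invoking it directly, assuming wf is an already-constructed brickflow Workflow with a task registered under the id "my_task" (the tasks dict is populated by Workflow._add_task in the next file):

# debugging sketch only; `wf` and the task id "my_task" are assumptions, not part of this commit
task = wf.tasks["my_task"]  # Task objects are stored on the workflow keyed by task_id
result = task.execute(ignore_all_deps=True)  # skips selection and trigger-rule checks
print(result)
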
11 changes: 11 additions & 0 deletions brickflow/engine/workflow.py
@@ -130,6 +130,7 @@ class Workflow:
run_as_service_principal: Optional[str] = None
# this is a databricks limit set on workflows; you can override it if you have an exception
max_tasks_in_workflow: int = 100
ensure_brickflow_plugins: Optional[bool] = None

def __post_init__(self) -> None:
self.graph.add_node(ROOT_NODE)
@@ -262,6 +263,7 @@ def _add_task(
trigger_rule: BrickflowTriggerRule = BrickflowTriggerRule.ALL_SUCCESS,
custom_execute_callback: Optional[Callable] = None,
task_settings: Optional[TaskSettings] = None,
ensure_brickflow_plugins: bool = False,
) -> None:
if self.task_exists(task_id):
raise TaskAlreadyExistsError(
@@ -279,6 +281,12 @@
if isinstance(depends_on, str) or callable(depends_on)
else depends_on
)

if self.ensure_brickflow_plugins is not None:
ensure_plugins = self.ensure_brickflow_plugins
else:
ensure_plugins = ensure_brickflow_plugins

self.tasks[task_id] = Task(
task_id=task_id,
task_func=f,
@@ -291,6 +299,7 @@
trigger_rule=trigger_rule,
task_settings=task_settings,
custom_execute_callback=custom_execute_callback,
ensure_brickflow_plugins=ensure_plugins,
)

# attempt to create task object before adding to graph
@@ -337,6 +346,7 @@ def task(
trigger_rule: BrickflowTriggerRule = BrickflowTriggerRule.ALL_SUCCESS,
custom_execute_callback: Optional[Callable] = None,
task_settings: Optional[TaskSettings] = None,
ensure_brickflow_plugins: bool = False,
) -> Callable:
if len(self.tasks) >= self.max_tasks_in_workflow:
raise ValueError(
@@ -358,6 +368,7 @@ def task_wrapper(f: Callable) -> Callable:
trigger_rule=trigger_rule,
custom_execute_callback=custom_execute_callback,
task_settings=task_settings,
ensure_brickflow_plugins=ensure_brickflow_plugins,
)

@functools.wraps(f)
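
With the changes above, ensure_brickflow_plugins can be set per task or once on the workflow; per _add_task, a workflow-level value that is not None wins over the per-task argument. A hedged usage sketch, assuming wf is a brickflow Workflow constructed elsewhere (its other constructor arguments are omitted here):

# hedged sketch; `wf` is assumed to be an existing Workflow, e.g. one created with
# ensure_brickflow_plugins=True so that every task fails fast when brickflow_plugins is missing
@wf.task(ensure_brickflow_plugins=True)  # per-task opt-in; a non-None workflow-level flag takes precedence
def airflow_backed_task(*, some_param="default"):
    # task body that relies on brickflow_plugins (airflow operators, cron helper, ...)
    return some_param

Tasks that do not touch brickflow_plugins can keep the default of False and will never import the extras.
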
34 changes: 1 addition & 33 deletions brickflow/resolver/__init__.py
@@ -4,7 +4,7 @@
import os
import sys
from pathlib import Path
from typing import Union, Dict, Any, List, Optional
from typing import Union, Any, List, Optional
import pathlib

from brickflow import BrickflowProjectConstants, _ilog, ctx
Expand Down Expand Up @@ -88,35 +88,3 @@ def get_notebook_ws_path(dbutils: Optional[Any]) -> Optional[str]:
)
)
return None


class RelativePathPackageResolver:
@staticmethod
def _get_current_file_path(global_vars: Dict[str, Any]) -> str:
if "dbutils" in global_vars:
ws_path = get_notebook_ws_path(global_vars["dbutils"])
if ws_path is None:
raise ValueError("Unable to resolve notebook path.")
return ws_path
else:
return global_vars["__file__"]

@staticmethod
def add_relative_path(
global_vars: Dict[str, Any],
current_file_to_root: str,
root_to_module: str = ".",
) -> None:
# root to module must always be relative to the root of the project (i.e. must not start with "/")
if root_to_module.startswith("/"):
raise ValueError(
f"root_to_module must be relative to the root of the project. "
f"It must not start with '/'. root_to_module: {root_to_module}"
)
p = (
Path(RelativePathPackageResolver._get_current_file_path(global_vars)).parent
/ Path(current_file_to_root)
/ root_to_module
)
path = p.resolve()
add_to_sys_path(path)
13 changes: 13 additions & 0 deletions brickflow_plugins/__init__.py
@@ -25,10 +25,23 @@ def setup_logger():
WorkflowDependencySensor,
)

from brickflow_plugins.airflow.cronhelper import cron_helper


def load_plugins():
from brickflow.engine.task import get_plugin_manager
from brickflow_plugins.airflow.brickflow_task_plugin import (
AirflowOperatorBrickflowTaskPluginImpl,
)

get_plugin_manager().register(AirflowOperatorBrickflowTaskPluginImpl())


__all__: List[str] = [
"TaskDependencySensor",
"BashOperator",
"BranchPythonOperator",
"ShortCircuitOperator",
"WorkflowDependencySensor",
"load_plugins",
]
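
load_plugins is what get_brickflow_tasks_hook (in brickflow/engine/task.py above) calls lazily; it can also be invoked explicitly when the extras are available. A hedged sketch, assuming brickflow[airflow] and py4j are installed:

# explicit registration sketch; normally get_brickflow_tasks_hook() does this for you
from brickflow.engine.task import get_plugin_manager
from brickflow_plugins import load_plugins

load_plugins()  # registers AirflowOperatorBrickflowTaskPluginImpl on the shared plugin manager
for name, plugin in get_plugin_manager().list_name_plugin():
    print(name, type(plugin).__name__)
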
10 changes: 9 additions & 1 deletion brickflow_plugins/airflow/cronhelper.py
@@ -1,7 +1,15 @@
import os
from pathlib import Path

from py4j.protocol import Py4JError
try:
from py4j.protocol import Py4JError
except ImportError:
raise ImportError(
"You must install py4j to use cronhelper, "
"please try pip install py4j. "
"This library is not installed as "
"it is provided by databricks OOTB."
)


class CronHelper:
57 changes: 0 additions & 57 deletions tests/test_brickflow.py
@@ -1,10 +1,4 @@
# pylint: disable=unused-import
from pathlib import Path
from unittest.mock import patch

import pytest

from brickflow.resolver import RelativePathPackageResolver


def test_imports():
@@ -44,54 +38,3 @@ def test_imports():
print("All imports Succeeded")
except ImportError as e:
print(f"Import failed: {e}")


def test_path_resolver():
with patch("brickflow.resolver.add_to_sys_path") as mock_add_to_sys_path, patch(
"brickflow.RelativePathPackageResolver._get_current_file_path"
) as mock_get_current_file_path:
mock_add_to_sys_path.return_value = None
mock_get_current_file_path.return_value = "/Some/Fake/Path/file.py"

# go up a directory and use the same
RelativePathPackageResolver.add_relative_path(
globals(), current_file_to_root="../", root_to_module="."
)

# Assertions
mock_add_to_sys_path.assert_called_once_with(Path("/Some/Fake"))
assert mock_get_current_file_path.called


def test_path_resolver_complex():
with patch("brickflow.resolver.add_to_sys_path") as mock_add_to_sys_path, patch(
"brickflow.RelativePathPackageResolver._get_current_file_path"
) as mock_get_current_file_path:
mock_add_to_sys_path.return_value = None
mock_get_current_file_path.return_value = "/Some/Fake/Path/file.py"

# go up 2 directories and then to /some/module
RelativePathPackageResolver.add_relative_path(
globals(), current_file_to_root="../../", root_to_module="./some/module"
)

# Assertions
mock_add_to_sys_path.assert_called_once_with(Path("/Some/some/module"))
assert mock_get_current_file_path.called


def test_path_resolver_root_to_module_abs():
with patch("brickflow.resolver.add_to_sys_path") as mock_add_to_sys_path, patch(
"brickflow.RelativePathPackageResolver._get_current_file_path"
) as mock_get_current_file_path:
mock_add_to_sys_path.return_value = None
mock_get_current_file_path.return_value = "/Some/Fake/Path"

# go up 2 directories and then to /some/module
with pytest.raises(
ValueError,
match="root_to_module must be relative to the root of the project",
):
RelativePathPackageResolver.add_relative_path(
globals(), current_file_to_root="../../", root_to_module="/some/module"
)
