[FEATURE] Add WorkflowTaskDependencySensor (#110)
* WorkflowTaskDependencySensor

* update docs

* fix request params

* add `get_job_id` to utils

* update utils and tests

---------

Co-authored-by: Maxim Mityutko <[email protected]>
maxim-mityutko and Maxim Mityutko authored May 6, 2024
1 parent 2450bf4 commit 2b70616
Showing 11 changed files with 401 additions and 17 deletions.
49 changes: 48 additions & 1 deletion brickflow/engine/utils.py
@@ -1,6 +1,10 @@
 import functools
-from typing import Callable, Type, List, Iterator
+from typing import Callable, Type, List, Iterator, Union
 
+from pydantic import SecretStr
+from databricks.sdk import WorkspaceClient
+
 from brickflow.context import ctx
 from brickflow.hints import propagate_hint

@@ -28,3 +32,46 @@ def _property_iter() -> Iterator[str]:
             yield k
 
     return list(_property_iter())
+
+
+def get_job_id(
+    job_name: str, host: Union[str, None] = None, token: Union[str, SecretStr] = None
+) -> Union[str, None]:
+    """
+    Get the job id from the specified Databricks workspace for a given job name.
+    Parameters
+    ----------
+    job_name: str
+        Job name (case-insensitive)
+    host: str
+        Databricks workspace URL
+    token: str
+        Databricks API token
+    Returns
+    -------
+    str
+        Databricks job id
+    """
+    ctx.log.info("Searching job id for job name: %s", job_name)
+
+    if host:
+        host = host.rstrip("/")
+    token = token.get_secret_value() if isinstance(token, SecretStr) else token
+
+    workspace_obj = WorkspaceClient(host=host, token=token)
+    jobs_list = workspace_obj.jobs.list(name=job_name)
+
+    try:
+        for job in jobs_list:
+            ctx.log.info("Job id for job '%s' is %s", job_name, job.job_id)
+            return job.job_id
+        else:  # pylint: disable=useless-else-on-loop
+            raise ValueError
+    except ValueError:
+        raise ValueError(f"No job found with name {job_name}")
+    except Exception as e:
+        ctx.log.info("An error occurred: %s", e)
+
+    return None
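
For reference, a minimal usage sketch of the new helper; the workspace URL, token, and job name below are placeholders, not values from this commit:

```python
from brickflow.engine.utils import get_job_id

# Look up a job id by (case-insensitive) name. Raises ValueError when no job
# matches; returns None if the listing fails with an unexpected error.
job_id = get_job_id(
    job_name="my_upstream_job",
    host="https://your_workspace_url.cloud.databricks.com",
    token="dapi...",  # a plain string or a pydantic SecretStr both work
)
```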
2 changes: 2 additions & 0 deletions brickflow_plugins/__init__.py
@@ -33,6 +33,7 @@ def setup_logger():
 )
 from brickflow_plugins.databricks.workflow_dependency_sensor import (
     WorkflowDependencySensor,
+    WorkflowTaskDependencySensor,
 )
 from brickflow_plugins.databricks.uc_to_snowflake_operator import (
     SnowflakeOperator,
@@ -69,6 +70,7 @@ def ensure_installation():
     "BranchPythonOperator",
     "ShortCircuitOperator",
     "WorkflowDependencySensor",
+    "WorkflowTaskDependencySensor",
     "SnowflakeOperator",
     "UcToSnowflakeOperator",
     "TableauRefreshDataSourceOperator",
135 changes: 120 additions & 15 deletions brickflow_plugins/databricks/workflow_dependency_sensor.py
@@ -1,15 +1,18 @@
 import functools
 import logging
 import os
+import time
 from datetime import datetime, timedelta
 from typing import Union
+from warnings import warn
 
 import requests
-import time
 from pydantic import SecretStr
 from requests.adapters import HTTPAdapter
+from databricks.sdk import WorkspaceClient
 
 from brickflow.context import ctx
+from brickflow.engine.utils import get_job_id
 
 
 class WorkflowDependencySensorException(Exception):
@@ -43,13 +46,13 @@ def __init__(
         self,
         databricks_host: str,
         databricks_token: Union[str, SecretStr],
-        dependency_job_id: int,
         delta: timedelta,
         timeout_seconds: int,
+        dependency_job_id: int = None,
+        dependency_job_name: str = None,
         poke_interval_seconds: int = 60,
     ):
         self.databricks_host = databricks_host
-        self.dependency_job_id = dependency_job_id
         self.databricks_token = (
             databricks_token
             if isinstance(databricks_token, SecretStr)
@@ -61,6 +64,25 @@ def __init__(
         self.log = logging
         self.start_time = time.time()
 
+        if dependency_job_id:
+            warn(
+                "Please use 'dependency_job_name' instead of 'dependency_job_id'",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        if not dependency_job_id and not dependency_job_name:
+            raise WorkflowDependencySensorException(
+                "Either dependency_job_id or dependency_job_name should be provided"
+            )
+
+        self.dependency_job_id = dependency_job_id
+        self.dependency_job_name = dependency_job_name
+
+        self._workspace_obj = WorkspaceClient(
+            host=self.databricks_host, token=self.databricks_token.get_secret_value()
+        )
+
     def get_retry_class(self, max_retries):
         from urllib3 import Retry
 
@@ -94,24 +116,18 @@ def get_http_session(self):
         session.mount("http://", HTTPAdapter(max_retries=retries))
         return session
 
-    def get_execution_start_time_unix_miliseconds(self) -> int:
-        session = self.get_http_session()
-        url = f"{self.databricks_host.rstrip('/')}/api/2.1/jobs/runs/get"
-        headers = {
-            "Authorization": f"Bearer {self.databricks_token.get_secret_value()}",
-            "Content-Type": "application/json",
-        }
+    def get_execution_start_time_unix_milliseconds(self) -> int:
         run_id = ctx.dbutils_widget_get_or_else("brickflow_parent_run_id", None)
         if run_id is None:
             raise WorkflowDependencySensorException(
                 "run_id is empty, brickflow_parent_run_id parameter is not found "
                 "or no value present"
             )
-        params = {"run_id": run_id}
-        resp = session.get(url, params=params, headers=headers).json()
 
-        # Convert Unix timestamp in miliseconds to datetime object to easily incorporate the delta
-        start_time = datetime.fromtimestamp(resp["start_time"] / 1000)
+        run = self._workspace_obj.jobs.get_run(run_id=run_id)
+
+        # Convert Unix timestamp in milliseconds to datetime object to easily incorporate the delta
+        start_time = datetime.fromtimestamp(run.start_time / 1000)
         execution_start_time = start_time - self.delta
 
         # Convert datetime object back to Unix timestamp in miliseconds
@@ -128,7 +144,18 @@ def get_execution_start_time_unix_miliseconds(self) -> int:
         )
         return execution_start_time_unix_miliseconds
 
+    @property
+    def _get_job_id(self):
+        return get_job_id(
+            host=self.databricks_host,
+            token=self.databricks_token,
+            job_name=self.dependency_job_name,
+        )
+
     def execute(self):
+        if not self.dependency_job_id:
+            self.dependency_job_id = self._get_job_id
+
         session = self.get_http_session()
         url = f"{self.databricks_host.rstrip('/')}/api/2.1/jobs/runs/list"
         headers = {
@@ -138,7 +165,7 @@ def execute(self):
         params = {
             "limit": 25,
             "job_id": self.dependency_job_id,
-            "start_time_from": self.get_execution_start_time_unix_miliseconds(),
+            "start_time_from": self.get_execution_start_time_unix_milliseconds(),
         }
 
         while True:

self.log.info(f"sleeping for: {self.poke_interval}")
time.sleep(self.poke_interval)


class WorkflowTaskDependencySensor(WorkflowDependencySensor):
"""
This is used to have dependencies on the specific task within a databricks workflow
Example Usage in your brickflow task:
service_principle_pat = ctx.dbutils.secrets.get("scope", "service_principle_id")
WorkflowDependencySensor(
databricks_host=https://your_workspace_url.cloud.databricks.com,
databricks_token=service_principle_pat,
dependency_job_id=job_id,
dependency_task_name="foo",
poke_interval=20,
timeout=60,
delta=timedelta(days=1)
)
In the above snippet Databricks secrets are used as a secure service to store the databricks token.
If you get your token from another secret management service, like AWS Secrets Manager, GCP Secret Manager
or Azure Key Vault, just pass it in the databricks_token argument.
"""

def __init__(
self,
dependency_job_name: str,
dependency_task_name: str,
delta: timedelta,
timeout_seconds: int,
databricks_host: str = None,
databricks_token: Union[str, SecretStr] = None,
poke_interval_seconds: int = 60,
):
super().__init__(
databricks_host=databricks_host,
databricks_token=databricks_token,
dependency_job_name=dependency_job_name,
delta=delta,
timeout_seconds=timeout_seconds,
poke_interval_seconds=poke_interval_seconds,
)

self.dependency_task_name = dependency_task_name

def execute(self):
self.dependency_job_id = self._get_job_id

while True:
runs_list = self._workspace_obj.jobs.list_runs(
job_id=self.dependency_job_id,
limit=25,
start_time_from=self.get_execution_start_time_unix_milliseconds(),
expand_tasks=True,
)

for run in runs_list:
for task in run.tasks:
if task.task_key == self.dependency_task_name:
task_state = task.state.result_state
self.log.info(
f"Found the run_id '{run.run_id}' and '{self.dependency_task_name}' "
f"task with state: {task_state.value}"
)
if task_state.value == "SUCCESS":
self.log.info(f"Found a successful run: {run.run_id}")
return

self.log.info("Didn't find a successful task run yet...")

if (
self.timeout is not None
and (time.time() - self.start_time) > self.timeout
):
raise WorkflowDependencySensorTimeOutException(
f"The job has timed out..."
)

self.log.info(f"Sleeping for: {self.poke_interval}")
time.sleep(self.poke_interval)
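
To make the polling knobs concrete, here is a usage sketch for the new class (job and task names are placeholders): the sensor only considers runs started after the parent run's start time minus `delta`, re-checks every `poke_interval_seconds`, and raises `WorkflowDependencySensorTimeOutException` once `timeout_seconds` have elapsed:

```python
from datetime import timedelta

from brickflow_plugins import WorkflowTaskDependencySensor

sensor = WorkflowTaskDependencySensor(
    databricks_host="https://your_workspace_url.cloud.databricks.com",
    databricks_token="dapi...",          # str or SecretStr
    dependency_job_name="upstream_job",  # resolved to a job id via get_job_id
    dependency_task_name="foo",          # matched against each run's task_key
    delta=timedelta(days=1),             # look-back window from the parent run's start
    timeout_seconds=600,                 # give up after 10 minutes
    poke_interval_seconds=60,            # poll once a minute
)
sensor.execute()  # blocks until the task succeeds or the sensor times out
```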
3 changes: 3 additions & 0 deletions docs/api/workflow_dependency_sensor.md
@@ -6,6 +6,9 @@ search:
 ::: brickflow_plugins.databricks.workflow_dependency_sensor
     handler: python
     options:
+      members:
+        - WorkflowDependencySensor
+        - WorkflowTaskDependencySensor
       filters:
         - "!^_[^_]"
         - "!^__[^__]"
25 changes: 25 additions & 0 deletions docs/faq/faq.md
@@ -59,6 +59,31 @@ def wait_on_workflow(*args):
     )
     sensor.execute()
 ```
+
+## How do I wait for a specific task in a workflow to finish before kicking off my own workflow's tasks?
+
+```python
+from datetime import timedelta
+
+from brickflow import Workflow
+from brickflow.context import ctx
+from brickflow_plugins import WorkflowTaskDependencySensor
+
+wf = Workflow(...)
+
+
+@wf.task
+def wait_on_workflow(*args):
+    api_token_key = ctx.dbutils.secrets.get("scope", "api_token_key")
+    sensor = WorkflowTaskDependencySensor(
+        databricks_host="https://your_workspace_url.cloud.databricks.com",
+        databricks_token=api_token_key,
+        dependency_job_name="foo-job",
+        dependency_task_name="foo",
+        poke_interval_seconds=20,
+        timeout_seconds=60,
+        delta=timedelta(days=1),
+    )
+    sensor.execute()
+```
 
 ## How do I run a sql query on snowflake from DBX?
 ```python
 from brickflow_plugins import SnowflakeOperator
26 changes: 26 additions & 0 deletions docs/tasks.md
@@ -397,6 +397,32 @@ def wait_on_workflow(*args):
     sensor.execute()
 ```
 
+#### Workflow Task Dependency Sensor
+
+Wait for a specific task in a workflow to finish before kicking off the current workflow's tasks.
+
+```python title="workflow_task_dependency_sensor"
+from datetime import timedelta
+
+from brickflow import Workflow
+from brickflow.context import ctx
+from brickflow_plugins import WorkflowTaskDependencySensor
+
+wf = Workflow(...)
+
+
+@wf.task
+def wait_on_workflow(*args):
+    api_token_key = ctx.dbutils.secrets.get("scope", "api_token_key")
+    sensor = WorkflowTaskDependencySensor(
+        databricks_host="https://your_workspace_url.cloud.databricks.com",
+        databricks_token=api_token_key,
+        dependency_job_name="foo-job",
+        dependency_task_name="foo",
+        poke_interval_seconds=20,
+        timeout_seconds=60,
+        delta=timedelta(days=1),
+    )
+    sensor.execute()
+```
+
 #### Snowflake Operator
 
 run snowflake queries from the databricks environment
19 changes: 18 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -69,6 +69,7 @@ apache-airflow = "^2.7.3"
 snowflake = "^0.6.0"
 tableauserverclient = "^0.25"
 watchdog = "<4.0.0"
+requests-mock = "1.12.1"
 
 [tool.poetry.group.docs.dependencies]
 mdx-include = ">=1.4.1,<2.0.0"
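
`requests-mock` is pinned for the new unit tests. Below is a hedged sketch of how `get_job_id` could be exercised with it, assuming the Databricks SDK issues a plain `GET` against `/api/2.1/jobs/list` and stops paginating when the response carries no `next_page_token`; the payload shape follows the public Jobs 2.1 API, and this is illustrative rather than the test shipped in this commit:

```python
import requests_mock

from brickflow.engine.utils import get_job_id


def test_get_job_id_resolves_by_name():
    # Intercept the SDK's HTTP call and return a single matching job.
    with requests_mock.Mocker() as m:
        m.get(
            "https://test.cloud.databricks.com/api/2.1/jobs/list",
            json={"jobs": [{"job_id": 1234, "settings": {"name": "foo"}}]},
        )
        job_id = get_job_id(
            job_name="foo",
            host="https://test.cloud.databricks.com",
            token="dummy",
        )
        assert job_id == 1234
```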
