diff --git a/brickflow/engine/utils.py b/brickflow/engine/utils.py
index 69863e77..c1173d6b 100644
--- a/brickflow/engine/utils.py
+++ b/brickflow/engine/utils.py
@@ -1,6 +1,10 @@
 import functools
-from typing import Callable, Type, List, Iterator
+from typing import Callable, Type, List, Iterator, Union
 
+from pydantic import SecretStr
+from databricks.sdk import WorkspaceClient
+
+from brickflow.context import ctx
 from brickflow.hints import propagate_hint
 
 
@@ -28,3 +32,46 @@ def _property_iter() -> Iterator[str]:
         yield k
 
     return list(_property_iter())
+
+
+def get_job_id(
+    job_name: str, host: Union[str, None] = None, token: Union[str, SecretStr] = None
+) -> Union[int, None]:
+    """
+    Get the job id from the specified Databricks workspace for a given job name.
+
+    Parameters
+    ----------
+    job_name: str
+        Job name (case-insensitive)
+    host: str
+        Databricks workspace URL
+    token: str
+        Databricks API token
+
+    Returns
+    -------
+    int
+        Databricks job id
+    """
+    ctx.log.info("Searching job id for job name: %s", job_name)
+
+    if host:
+        host = host.rstrip("/")
+    token = token.get_secret_value() if isinstance(token, SecretStr) else token
+
+    workspace_obj = WorkspaceClient(host=host, token=token)
+    jobs_list = workspace_obj.jobs.list(name=job_name)
+
+    try:
+        for job in jobs_list:
+            ctx.log.info("Job id for job '%s' is %s", job_name, job.job_id)
+            return job.job_id
+        else:  # pylint: disable=useless-else-on-loop
+            raise ValueError
+    except ValueError:
+        raise ValueError(f"No job found with name {job_name}")
+    except Exception as e:
+        ctx.log.info("An error occurred: %s", e)
+
+    return None
diff --git a/brickflow_plugins/__init__.py b/brickflow_plugins/__init__.py
index 8dfbe37a..9151feda 100644
--- a/brickflow_plugins/__init__.py
+++ b/brickflow_plugins/__init__.py
@@ -33,6 +33,7 @@ def setup_logger():
 )
 from brickflow_plugins.databricks.workflow_dependency_sensor import (
     WorkflowDependencySensor,
+    WorkflowTaskDependencySensor,
 )
 from brickflow_plugins.databricks.uc_to_snowflake_operator import (
     SnowflakeOperator,
@@ -69,6 +70,7 @@ def ensure_installation():
     "BranchPythonOperator",
     "ShortCircuitOperator",
     "WorkflowDependencySensor",
+    "WorkflowTaskDependencySensor",
     "SnowflakeOperator",
     "UcToSnowflakeOperator",
     "TableauRefreshDataSourceOperator",
diff --git a/brickflow_plugins/databricks/workflow_dependency_sensor.py b/brickflow_plugins/databricks/workflow_dependency_sensor.py
index 814f81af..e85047ba 100644
--- a/brickflow_plugins/databricks/workflow_dependency_sensor.py
+++ b/brickflow_plugins/databricks/workflow_dependency_sensor.py
@@ -1,15 +1,18 @@
 import functools
 import logging
 import os
+import time
 from datetime import datetime, timedelta
 from typing import Union
+from warnings import warn
 
 import requests
-import time
 from pydantic import SecretStr
 from requests.adapters import HTTPAdapter
+from databricks.sdk import WorkspaceClient
 
 from brickflow.context import ctx
+from brickflow.engine.utils import get_job_id
 
 
 class WorkflowDependencySensorException(Exception):
@@ -43,13 +46,13 @@ def __init__(
         self,
         databricks_host: str,
         databricks_token: Union[str, SecretStr],
-        dependency_job_id: int,
         delta: timedelta,
         timeout_seconds: int,
+        dependency_job_id: int = None,
+        dependency_job_name: str = None,
         poke_interval_seconds: int = 60,
     ):
         self.databricks_host = databricks_host
-        self.dependency_job_id = dependency_job_id
         self.databricks_token = (
             databricks_token
             if isinstance(databricks_token, SecretStr)
@@ -61,6 +64,25 @@ def __init__(
         self.log = logging
         self.start_time = time.time()
 
+        if dependency_job_id:
+            warn(
+                "Please use 'dependency_job_name' instead of 'dependency_job_id'",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        if not dependency_job_id and not dependency_job_name:
+            raise WorkflowDependencySensorException(
+                "Either dependency_job_id or dependency_job_name should be provided"
+            )
+
+        self.dependency_job_id = dependency_job_id
+        self.dependency_job_name = dependency_job_name
+
+        self._workspace_obj = WorkspaceClient(
+            host=self.databricks_host, token=self.databricks_token.get_secret_value()
+        )
+
     def get_retry_class(self, max_retries):
         from urllib3 import Retry
 
@@ -94,24 +116,18 @@ def get_http_session(self):
         session.mount("http://", HTTPAdapter(max_retries=retries))
         return session
 
-    def get_execution_start_time_unix_miliseconds(self) -> int:
-        session = self.get_http_session()
-        url = f"{self.databricks_host.rstrip('/')}/api/2.1/jobs/runs/get"
-        headers = {
-            "Authorization": f"Bearer {self.databricks_token.get_secret_value()}",
-            "Content-Type": "application/json",
-        }
+    def get_execution_start_time_unix_milliseconds(self) -> int:
         run_id = ctx.dbutils_widget_get_or_else("brickflow_parent_run_id", None)
         if run_id is None:
             raise WorkflowDependencySensorException(
                 "run_id is empty, brickflow_parent_run_id parameter is not found "
                 "or no value present"
             )
 
-        params = {"run_id": run_id}
-        resp = session.get(url, params=params, headers=headers).json()
-        # Convert Unix timestamp in miliseconds to datetime object to easily incorporate the delta
-        start_time = datetime.fromtimestamp(resp["start_time"] / 1000)
+        run = self._workspace_obj.jobs.get_run(run_id=run_id)
+
+        # Convert Unix timestamp in milliseconds to datetime object to easily incorporate the delta
+        start_time = datetime.fromtimestamp(run.start_time / 1000)
         execution_start_time = start_time - self.delta
 
         # Convert datetime object back to Unix timestamp in miliseconds
@@ -128,7 +144,18 @@ def get_execution_start_time_unix_miliseconds(self) -> int:
         )
         return execution_start_time_unix_miliseconds
 
+    @property
+    def _get_job_id(self):
+        return get_job_id(
+            host=self.databricks_host,
+            token=self.databricks_token,
+            job_name=self.dependency_job_name,
+        )
+
     def execute(self):
+        if not self.dependency_job_id:
+            self.dependency_job_id = self._get_job_id
+
         session = self.get_http_session()
         url = f"{self.databricks_host.rstrip('/')}/api/2.1/jobs/runs/list"
         headers = {
@@ -138,7 +165,7 @@ def execute(self):
         params = {
             "limit": 25,
             "job_id": self.dependency_job_id,
-            "start_time_from": self.get_execution_start_time_unix_miliseconds(),
+            "start_time_from": self.get_execution_start_time_unix_milliseconds(),
         }
 
         while True:
@@ -171,3 +198,81 @@ def execute(self):
 
             self.log.info(f"sleeping for: {self.poke_interval}")
             time.sleep(self.poke_interval)
+
+
+class WorkflowTaskDependencySensor(WorkflowDependencySensor):
+    """
+    This is used to have dependencies on a specific task within a Databricks workflow
+
+    Example Usage in your brickflow task:
+        service_principle_pat = ctx.dbutils.secrets.get("scope", "service_principle_id")
+        WorkflowTaskDependencySensor(
+            databricks_host="https://your_workspace_url.cloud.databricks.com",
+            databricks_token=service_principle_pat,
+            dependency_job_name="foo",
+            dependency_task_name="foo",
+            poke_interval_seconds=20,
+            timeout_seconds=60,
+            delta=timedelta(days=1)
+        )
+    In the above snippet, Databricks secrets are used as a secure service to store the Databricks token.
+    If you get your token from another secret management service, like AWS Secrets Manager, GCP Secret Manager
+    or Azure Key Vault, just pass it in the databricks_token argument.
+    """
+
+    def __init__(
+        self,
+        dependency_job_name: str,
+        dependency_task_name: str,
+        delta: timedelta,
+        timeout_seconds: int,
+        databricks_host: str = None,
+        databricks_token: Union[str, SecretStr] = None,
+        poke_interval_seconds: int = 60,
+    ):
+        super().__init__(
+            databricks_host=databricks_host,
+            databricks_token=databricks_token,
+            dependency_job_name=dependency_job_name,
+            delta=delta,
+            timeout_seconds=timeout_seconds,
+            poke_interval_seconds=poke_interval_seconds,
+        )
+
+        self.dependency_task_name = dependency_task_name
+
+    def execute(self):
+        self.dependency_job_id = self._get_job_id
+
+        while True:
+            runs_list = self._workspace_obj.jobs.list_runs(
+                job_id=self.dependency_job_id,
+                limit=25,
+                start_time_from=self.get_execution_start_time_unix_milliseconds(),
+                expand_tasks=True,
+            )
+
+            for run in runs_list:
+                for task in run.tasks:
+                    if task.task_key == self.dependency_task_name:
+                        task_state = task.state.result_state
+                        self.log.info(
+                            f"Found the run_id '{run.run_id}' and '{self.dependency_task_name}' "
+                            f"task with state: {task_state.value}"
+                        )
+                        if task_state.value == "SUCCESS":
+                            self.log.info(f"Found a successful run: {run.run_id}")
+                            return
+
+            self.log.info("Didn't find a successful task run yet...")
+
+            if (
+                self.timeout is not None
+                and (time.time() - self.start_time) > self.timeout
+            ):
+                raise WorkflowDependencySensorTimeOutException(
+                    "The job has timed out..."
+                )
+
+            self.log.info(f"Sleeping for: {self.poke_interval}")
+            time.sleep(self.poke_interval)
diff --git a/docs/api/workflow_dependency_sensor.md b/docs/api/workflow_dependency_sensor.md
index e5ab74fd..178cd32a 100644
--- a/docs/api/workflow_dependency_sensor.md
+++ b/docs/api/workflow_dependency_sensor.md
@@ -6,6 +6,9 @@ search:
 ::: brickflow_plugins.databricks.workflow_dependency_sensor
     handler: python
     options:
+      members:
+        - WorkflowDependencySensor
+        - WorkflowTaskDependencySensor
       filters:
         - "!^_[^_]"
         - "!^__[^__]"
diff --git a/docs/faq/faq.md b/docs/faq/faq.md
index 34a87692..02d4d40c 100644
--- a/docs/faq/faq.md
+++ b/docs/faq/faq.md
@@ -59,6 +59,31 @@ def wait_on_workflow(*args):
     )
     sensor.execute()
 ```
+
+## How do I wait for a specific task in a workflow to finish before kicking off my own workflow's tasks?
+
+```python
+from brickflow.context import ctx
+from brickflow_plugins import WorkflowTaskDependencySensor
+
+wf = Workflow(...)
+
+
+@wf.task
+def wait_on_workflow(*args):
+    api_token_key = ctx.dbutils.secrets.get("scope", "api_token_key")
+    sensor = WorkflowTaskDependencySensor(
+        databricks_host="https://your_workspace_url.cloud.databricks.com",
+        databricks_token=api_token_key,
+        dependency_job_name="foo",
+        dependency_task_name="foo",
+        poke_interval_seconds=20,
+        timeout_seconds=60,
+        delta=timedelta(days=1)
+    )
+    sensor.execute()
+```
+
 ## How do I run a sql query on snowflake from DBX?
 ```python
 from brickflow_plugins import SnowflakeOperator
diff --git a/docs/tasks.md b/docs/tasks.md
index 93c2fc68..f8562b25 100644
--- a/docs/tasks.md
+++ b/docs/tasks.md
@@ -397,6 +397,32 @@ def wait_on_workflow(*args):
     sensor.execute()
 ```
 
+#### Workflow Task Dependency Sensor
+
+Wait for a specific task in a workflow to finish before kicking off the current workflow's tasks
+
+```python title="workflow_task_dependency_sensor"
+from brickflow.context import ctx
+from brickflow_plugins import WorkflowTaskDependencySensor
+
+wf = Workflow(...)
+
+
+@wf.task
+def wait_on_workflow(*args):
+    api_token_key = ctx.dbutils.secrets.get("scope", "api_token_key")
+    sensor = WorkflowTaskDependencySensor(
+        databricks_host="https://your_workspace_url.cloud.databricks.com",
+        databricks_token=api_token_key,
+        dependency_job_name="foo",
+        dependency_task_name="foo",
+        poke_interval_seconds=20,
+        timeout_seconds=60,
+        delta=timedelta(days=1)
+    )
+    sensor.execute()
+```
+
 #### Snowflake Operator
 
 run snowflake queries from the databricks environment
diff --git a/poetry.lock b/poetry.lock
index 2a178e19..929bb5ed 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -4284,6 +4284,23 @@ urllib3 = ">=1.21.1,<3"
 socks = ["PySocks (>=1.5.6,!=1.5.7)"]
 use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
 
+[[package]]
+name = "requests-mock"
+version = "1.12.1"
+description = "Mock out responses from the requests package"
+optional = false
+python-versions = ">=3.5"
+files = [
+    {file = "requests-mock-1.12.1.tar.gz", hash = "sha256:e9e12e333b525156e82a3c852f22016b9158220d2f47454de9cae8a77d371401"},
+    {file = "requests_mock-1.12.1-py2.py3-none-any.whl", hash = "sha256:b1e37054004cdd5e56c84454cc7df12b25f90f382159087f4b6915aaeef39563"},
+]
+
+[package.dependencies]
+requests = ">=2.22,<3"
+
+[package.extras]
+fixture = ["fixtures"]
+
 [[package]]
 name = "requests-toolbelt"
 version = "1.0.0"
@@ -5474,4 +5491,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8,<3.11"
-content-hash = "5ac4fca6a235d709a992a8f6a34db72fddc1b532cc1a8beee278e63c41988cb7"
+content-hash = "b87bb4017bbdb0d96e0bf9f8475b7a262bd630d328011642ac07fc15885b7858"
diff --git a/pyproject.toml b/pyproject.toml
index 92ee243b..99e45881 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,6 +69,7 @@ apache-airflow = "^2.7.3"
 snowflake = "^0.6.0"
 tableauserverclient = "^0.25"
 watchdog = "<4.0.0"
+requests-mock = "1.12.1"
 
 [tool.poetry.group.docs.dependencies]
 mdx-include = ">=1.4.1,<2.0.0"
diff --git a/tests/databricks_plugins/__init__.py b/tests/databricks_plugins/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/databricks_plugins/test_workflow_task_dependency_sensor.py b/tests/databricks_plugins/test_workflow_task_dependency_sensor.py
new file mode 100644
index 00000000..f870cc2f
--- /dev/null
+++ b/tests/databricks_plugins/test_workflow_task_dependency_sensor.py
@@ -0,0 +1,103 @@
+from datetime import timedelta
+
+import pytest
+from requests_mock.mocker import Mocker as RequestsMocker
+
+from brickflow_plugins.databricks.workflow_dependency_sensor import (
+    WorkflowTaskDependencySensor,
+    WorkflowDependencySensorTimeOutException,
+)
+
+
+class TestWorkflowTaskDependencySensor:
+    workspace_url = "https://42.cloud.databricks.com"
+    endpoint_url = f"{workspace_url}/api/2.1/jobs/runs/list"
+    response = {
+        "runs": [
+            {
+                "job_id": 1,
+                "run_id": 1,
+                "start_time": 1704063600000,
+                "state": {
+                    "result_state": "SUCCESS",
+                },
+                "tasks": [
+                    {
+                        "run_id": 100,
"foo", + "state": { + "result_state": "SUCCESS", + }, + }, + { + "run_id": 200, + "task_key": "bar", + "state": { + "result_state": "FAILED", + }, + }, + ], + } + ] + } + + @pytest.fixture(autouse=True) + def mock_get_execution_start_time_unix_milliseconds(self, mocker): + mocker.patch.object( + WorkflowTaskDependencySensor, + "get_execution_start_time_unix_milliseconds", + return_value=1704063600000, + ) + + @pytest.fixture(autouse=True) + def mock_get_job_id(self, mocker): + mocker.patch( + "brickflow_plugins.databricks.workflow_dependency_sensor.get_job_id", + return_value=1, + ) + + @pytest.fixture(autouse=True, name="api") + def mock_api(self): + rm = RequestsMocker() + rm.get(self.endpoint_url, json=self.response, status_code=int(200)) + yield rm + + def test_sensor_success(self, caplog, api): + with api: + sensor = WorkflowTaskDependencySensor( + databricks_host=self.workspace_url, + databricks_token="token", + dependency_job_name="job", + dependency_task_name="foo", + delta=timedelta(seconds=1), + timeout_seconds=1, + poke_interval_seconds=1, + ) + + sensor.execute() + + assert ( + "Found the run_id '1' and 'foo' task with state: SUCCESS" in caplog.text + ) + assert "Found a successful run: 1" in caplog.text + + def test_sensor_failure(self, caplog, api): + with api: + sensor = WorkflowTaskDependencySensor( + databricks_host=self.workspace_url, + databricks_token="token", + dependency_job_name="job", + dependency_task_name="bar", + delta=timedelta(seconds=1), + timeout_seconds=1, + poke_interval_seconds=1, + ) + + with pytest.raises(WorkflowDependencySensorTimeOutException): + sensor.execute() + assert "The job has timed out..." in caplog.text + assert "Didn't find a successful task run yet..." in caplog.text + assert ( + "Found the run_id '1' and 'bar' task with state: FAILED" + in caplog.text + ) diff --git a/tests/engine/test_utils.py b/tests/engine/test_utils.py new file mode 100644 index 00000000..ce8eb07d --- /dev/null +++ b/tests/engine/test_utils.py @@ -0,0 +1,55 @@ +import pytest +from requests_mock.mocker import Mocker as RequestsMocker + +from pydantic import SecretStr + +from brickflow.engine.utils import get_job_id, ctx + + +class TestUtils: + workspace_url = "https://42.cloud.databricks.com" + endpoint_url = f"{workspace_url}/api/2.1/jobs/list" + + ctx.log.propagate = True + + @pytest.fixture(autouse=True, name="api", scope="class") + def mock_api(self): + rm = RequestsMocker() + rm.register_uri( + method="GET", + url=self.endpoint_url, + response_list=[ + { + "json": {"jobs": [{"job_id": 1234, "settings": {"name": "foo"}}]}, + "status_code": int(200), + }, + { + "json": {"has_more": False}, + "status_code": int(200), + }, + { + "json": {}, + "status_code": int(404), + }, + ], + ) + yield rm + + def test_get_job_id_success(self, api): + with api: + job_id = get_job_id( + job_name="foo", + host=self.workspace_url, + token=SecretStr("token"), + ) + assert job_id == 1234 + + def test_get_job_id_failure(self, api): + with pytest.raises(ValueError): + with api: + get_job_id(job_name="bar", host=self.workspace_url, token="token") + + def test_get_job_id_non_200(self, caplog, api): + with api: + get_job_id(job_name="buz", host=self.workspace_url, token="token") + assert "An error occurred: request failed" in caplog.text