From e867dbc3daa9e946362c53aa4beebb8c0621faad Mon Sep 17 00:00:00 2001 From: Mayuri Nehate <33225191+mayurinehate@users.noreply.github.com> Date: Thu, 31 Aug 2023 02:38:42 +0530 Subject: [PATCH] ci: separate airflow build and test (#8688) Co-authored-by: Harshal Sheth --- .github/workflows/airflow-plugin.yml | 85 +++ .github/workflows/metadata-ingestion.yml | 7 +- .github/workflows/test-results.yml | 2 +- docs/lineage/airflow.md | 6 +- .../airflow-plugin/build.gradle | 59 +- .../airflow-plugin/pyproject.toml | 1 - .../airflow-plugin/setup.cfg | 4 +- .../airflow-plugin/setup.py | 24 +- .../datahub_airflow_plugin/_airflow_compat.py | 12 + .../datahub_airflow_plugin/_airflow_shims.py | 29 + .../datahub_airflow_plugin/_lineage_core.py | 115 ++++ .../client}/__init__.py | 0 .../client/airflow_generator.py | 512 ++++++++++++++++++ .../datahub_airflow_plugin/datahub_plugin.py | 371 ++++++++++++- .../src/datahub_airflow_plugin/entities.py | 47 ++ .../example_dags/.airflowignore | 0 .../example_dags/__init__.py | 0 .../example_dags/generic_recipe_sample_dag.py | 2 +- .../example_dags/lineage_backend_demo.py | 3 +- .../lineage_backend_taskflow_demo.py | 3 +- .../example_dags/lineage_emission_dag.py | 5 +- .../example_dags/mysql_sample_dag.py | 1 + .../example_dags/snowflake_sample_dag.py | 1 + .../datahub_airflow_plugin/hooks/__init__.py | 0 .../datahub_airflow_plugin/hooks/datahub.py | 214 ++++++++ .../lineage/__init__.py | 0 .../datahub_airflow_plugin/lineage/datahub.py | 91 ++++ .../operators/__init__.py | 0 .../operators/datahub.py | 63 +++ .../operators/datahub_assertion_operator.py | 78 +++ .../operators/datahub_assertion_sensor.py | 78 +++ .../operators/datahub_operation_operator.py | 97 ++++ .../operators/datahub_operation_sensor.py | 100 ++++ .../tests/unit/test_airflow.py | 16 +- metadata-ingestion/developing.md | 12 +- metadata-ingestion/schedule_docs/airflow.md | 6 +- metadata-ingestion/setup.cfg | 3 - metadata-ingestion/setup.py | 10 +- .../src/datahub_provider/__init__.py | 29 +- .../src/datahub_provider/_airflow_compat.py | 13 +- .../src/datahub_provider/_airflow_shims.py | 34 +- .../src/datahub_provider/_lineage_core.py | 115 +--- .../src/datahub_provider/_plugin.py | 369 +------------ .../client/airflow_generator.py | 510 +---------------- .../src/datahub_provider/entities.py | 49 +- .../src/datahub_provider/hooks/datahub.py | 220 +------- .../src/datahub_provider/lineage/datahub.py | 93 +--- .../src/datahub_provider/operators/datahub.py | 65 +-- .../operators/datahub_assertion_operator.py | 79 +-- .../operators/datahub_assertion_sensor.py | 79 +-- .../operators/datahub_operation_operator.py | 98 +--- .../operators/datahub_operation_sensor.py | 101 +--- 52 files changed, 2037 insertions(+), 1874 deletions(-) create mode 100644 .github/workflows/airflow-plugin.yml create mode 100644 metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_airflow_compat.py create mode 100644 metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_airflow_shims.py create mode 100644 metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_lineage_core.py rename {metadata-ingestion/src/datahub_provider/example_dags => metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/client}/__init__.py (100%) create mode 100644 metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/client/airflow_generator.py create mode 100644 metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/entities.py rename 
{metadata-ingestion/src/datahub_provider => metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin}/example_dags/.airflowignore (100%) rename .github/workflows/docker-ingestion-base.yml => metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/__init__.py (100%) rename {metadata-ingestion/src/datahub_provider => metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin}/example_dags/generic_recipe_sample_dag.py (98%) rename {metadata-ingestion/src/datahub_provider => metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin}/example_dags/lineage_backend_demo.py (94%) rename {metadata-ingestion/src/datahub_provider => metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin}/example_dags/lineage_backend_taskflow_demo.py (94%) rename {metadata-ingestion/src/datahub_provider => metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin}/example_dags/lineage_emission_dag.py (96%) rename {metadata-ingestion/src/datahub_provider => metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin}/example_dags/mysql_sample_dag.py (98%) rename {metadata-ingestion/src/datahub_provider => metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin}/example_dags/snowflake_sample_dag.py (99%) create mode 100644 metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/hooks/__init__.py create mode 100644 metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/hooks/datahub.py create mode 100644 metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/lineage/__init__.py create mode 100644 metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/lineage/datahub.py create mode 100644 metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/__init__.py create mode 100644 metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub.py create mode 100644 metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub_assertion_operator.py create mode 100644 metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub_assertion_sensor.py create mode 100644 metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub_operation_operator.py create mode 100644 metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub_operation_sensor.py rename {metadata-ingestion => metadata-ingestion-modules/airflow-plugin}/tests/unit/test_airflow.py (97%) diff --git a/.github/workflows/airflow-plugin.yml b/.github/workflows/airflow-plugin.yml new file mode 100644 index 0000000000000..63bab821cc398 --- /dev/null +++ b/.github/workflows/airflow-plugin.yml @@ -0,0 +1,85 @@ +name: Airflow Plugin +on: + push: + branches: + - master + paths: + - ".github/workflows/airflow-plugin.yml" + - "metadata-ingestion-modules/airflow-plugin/**" + - "metadata-ingestion/**" + - "metadata-models/**" + pull_request: + branches: + - master + paths: + - ".github/**" + - "metadata-ingestion-modules/airflow-plugin/**" + - "metadata-ingestion/**" + - "metadata-models/**" + release: + types: [published] + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + airflow-plugin: + runs-on: ubuntu-latest + env: + SPARK_VERSION: 3.0.3 + DATAHUB_TELEMETRY_ENABLED: false + strategy: + matrix: + include: + - python-version: "3.7" + extraPythonRequirement: "apache-airflow~=2.1.0" + - 
python-version: "3.7" + extraPythonRequirement: "apache-airflow~=2.2.0" + - python-version: "3.10" + extraPythonRequirement: "apache-airflow~=2.4.0" + - python-version: "3.10" + extraPythonRequirement: "apache-airflow~=2.6.0" + - python-version: "3.10" + extraPythonRequirement: "apache-airflow>2.6.0" + fail-fast: false + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: "pip" + - name: Install dependencies + run: ./metadata-ingestion/scripts/install_deps.sh + - name: Install airflow package and test (extras ${{ matrix.extraPythonRequirement }}) + run: ./gradlew -Pextra_pip_requirements='${{ matrix.extraPythonRequirement }}' :metadata-ingestion-modules:airflow-plugin:lint :metadata-ingestion-modules:airflow-plugin:testQuick + - name: pip freeze show list installed + if: always() + run: source metadata-ingestion-modules/airflow-plugin/venv/bin/activate && pip freeze + - uses: actions/upload-artifact@v3 + if: ${{ always() && matrix.python-version == '3.10' && matrix.extraPythonRequirement == 'apache-airflow>2.6.0' }} + with: + name: Test Results (Airflow Plugin ${{ matrix.python-version}}) + path: | + **/build/reports/tests/test/** + **/build/test-results/test/** + **/junit.*.xml + - name: Upload coverage to Codecov + if: always() + uses: codecov/codecov-action@v3 + with: + token: ${{ secrets.CODECOV_TOKEN }} + directory: . + fail_ci_if_error: false + flags: airflow-${{ matrix.python-version }}-${{ matrix.extraPythonRequirement }} + name: pytest-airflow + verbose: true + + event-file: + runs-on: ubuntu-latest + steps: + - name: Upload + uses: actions/upload-artifact@v3 + with: + name: Event File + path: ${{ github.event_path }} diff --git a/.github/workflows/metadata-ingestion.yml b/.github/workflows/metadata-ingestion.yml index fb70c85fdec93..fff41e481c3cb 100644 --- a/.github/workflows/metadata-ingestion.yml +++ b/.github/workflows/metadata-ingestion.yml @@ -42,9 +42,7 @@ jobs: ] include: - python-version: "3.7" - extraPythonRequirement: "sqlalchemy==1.3.24 apache-airflow~=2.2.0" - python-version: "3.10" - extraPythonRequirement: "sqlalchemy~=1.4.0 apache-airflow>=2.4.0" fail-fast: false steps: - uses: actions/checkout@v3 @@ -56,8 +54,8 @@ jobs: run: ./metadata-ingestion/scripts/install_deps.sh - name: Install package run: ./gradlew :metadata-ingestion:installPackageOnly - - name: Run metadata-ingestion tests (extras ${{ matrix.extraPythonRequirement }}) - run: ./gradlew -Pextra_pip_requirements='${{ matrix.extraPythonRequirement }}' :metadata-ingestion:${{ matrix.command }} + - name: Run metadata-ingestion tests + run: ./gradlew :metadata-ingestion:${{ matrix.command }} - name: pip freeze show list installed if: always() run: source metadata-ingestion/venv/bin/activate && pip freeze @@ -80,7 +78,6 @@ jobs: name: pytest-${{ matrix.command }} verbose: true - event-file: runs-on: ubuntu-latest steps: diff --git a/.github/workflows/test-results.yml b/.github/workflows/test-results.yml index 656e4dcbc4e43..0153060692271 100644 --- a/.github/workflows/test-results.yml +++ b/.github/workflows/test-results.yml @@ -2,7 +2,7 @@ name: Test Results on: workflow_run: - workflows: ["build & test", "metadata ingestion"] + workflows: ["build & test", "metadata ingestion", "Airflow Plugin"] types: - completed diff --git a/docs/lineage/airflow.md b/docs/lineage/airflow.md index 21d59b777dd7c..49de5352f6d58 100644 --- a/docs/lineage/airflow.md +++ b/docs/lineage/airflow.md @@ -65,7 +65,7 @@ lazy_load_plugins = False | 
datahub.capture_executions | true | If true, we'll capture task runs in DataHub in addition to DAG definitions. | | datahub.graceful_exceptions | true | If set to true, most runtime errors in the lineage backend will be suppressed and will not cause the overall task to fail. Note that configuration issues will still throw exceptions. | -5. Configure `inlets` and `outlets` for your Airflow operators. For reference, look at the sample DAG in [`lineage_backend_demo.py`](../../metadata-ingestion/src/datahub_provider/example_dags/lineage_backend_demo.py), or reference [`lineage_backend_taskflow_demo.py`](../../metadata-ingestion/src/datahub_provider/example_dags/lineage_backend_taskflow_demo.py) if you're using the [TaskFlow API](https://airflow.apache.org/docs/apache-airflow/stable/concepts/taskflow.html). +5. Configure `inlets` and `outlets` for your Airflow operators. For reference, look at the sample DAG in [`lineage_backend_demo.py`](../../metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_backend_demo.py), or reference [`lineage_backend_taskflow_demo.py`](../../metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_backend_taskflow_demo.py) if you're using the [TaskFlow API](https://airflow.apache.org/docs/apache-airflow/stable/concepts/taskflow.html). 6. [optional] Learn more about [Airflow lineage](https://airflow.apache.org/docs/apache-airflow/stable/lineage.html), including shorthand notation and some automation. ### How to validate installation @@ -160,14 +160,14 @@ pip install acryl-datahub[airflow,datahub-kafka] - `capture_executions` (defaults to false): If true, it captures task runs as DataHub DataProcessInstances. - `graceful_exceptions` (defaults to true): If set to true, most runtime errors in the lineage backend will be suppressed and will not cause the overall task to fail. Note that configuration issues will still throw exceptions. -4. Configure `inlets` and `outlets` for your Airflow operators. For reference, look at the sample DAG in [`lineage_backend_demo.py`](../../metadata-ingestion/src/datahub_provider/example_dags/lineage_backend_demo.py), or reference [`lineage_backend_taskflow_demo.py`](../../metadata-ingestion/src/datahub_provider/example_dags/lineage_backend_taskflow_demo.py) if you're using the [TaskFlow API](https://airflow.apache.org/docs/apache-airflow/stable/concepts/taskflow.html). +4. Configure `inlets` and `outlets` for your Airflow operators. For reference, look at the sample DAG in [`lineage_backend_demo.py`](../../metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_backend_demo.py), or reference [`lineage_backend_taskflow_demo.py`](../../metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_backend_taskflow_demo.py) if you're using the [TaskFlow API](https://airflow.apache.org/docs/apache-airflow/stable/concepts/taskflow.html). 5. [optional] Learn more about [Airflow lineage](https://airflow.apache.org/docs/apache-airflow/stable/lineage.html), including shorthand notation and some automation. ## Emitting lineage via a separate operator Take a look at this sample DAG: -- [`lineage_emission_dag.py`](../../metadata-ingestion/src/datahub_provider/example_dags/lineage_emission_dag.py) - emits lineage using the DatahubEmitterOperator. 
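As a reading aid for the `inlets`/`outlets` step referenced in the doc hunks above, here is a condensed sketch in the spirit of `lineage_backend_demo.py`, using the `Dataset` and `Urn` helpers that this patch adds under `datahub_airflow_plugin/entities.py`. The DAG id, task, and Snowflake table names are illustrative, not copied from the sample DAG:

```python
from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator

from datahub_airflow_plugin.entities import Dataset, Urn

with DAG(
    "datahub_lineage_backend_demo",  # illustrative DAG id
    start_date=datetime(2023, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # The lineage backend reads these inlets/outlets when the task finishes
    # and emits them as dataset-level lineage on the task's DataJob.
    run_data_task = BashOperator(
        task_id="run_data_task",
        bash_command="echo 'This is where you might run your data tooling.'",
        inlets=[
            Dataset(platform="snowflake", name="mydb.schema.tableA"),
            Urn("urn:li:dataset:(urn:li:dataPlatform:snowflake,mydb.schema.tableB,PROD)"),
        ],
        outlets=[Dataset(platform="snowflake", name="mydb.schema.tableC")],
    )
```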
+- [`lineage_emission_dag.py`](../../metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_emission_dag.py) - emits lineage using the DatahubEmitterOperator. In order to use this example, you must first configure the Datahub hook. Like in ingestion, we support a Datahub REST hook and a Kafka-based hook. See step 1 above for details. diff --git a/metadata-ingestion-modules/airflow-plugin/build.gradle b/metadata-ingestion-modules/airflow-plugin/build.gradle index 336be8fc94d44..d1e6f2f646491 100644 --- a/metadata-ingestion-modules/airflow-plugin/build.gradle +++ b/metadata-ingestion-modules/airflow-plugin/build.gradle @@ -7,6 +7,10 @@ ext { venv_name = 'venv' } +if (!project.hasProperty("extra_pip_requirements")) { + ext.extra_pip_requirements = "" +} + def pip_install_command = "${venv_name}/bin/pip install -e ../../metadata-ingestion" task checkPythonVersion(type: Exec) { @@ -14,30 +18,37 @@ task checkPythonVersion(type: Exec) { } task environmentSetup(type: Exec, dependsOn: checkPythonVersion) { + def sentinel_file = "${venv_name}/.venv_environment_sentinel" inputs.file file('setup.py') - outputs.dir("${venv_name}") - commandLine 'bash', '-c', "${python_executable} -m venv ${venv_name} && ${venv_name}/bin/python -m pip install --upgrade pip wheel 'setuptools>=63.0.0'" + outputs.file(sentinel_file) + commandLine 'bash', '-c', + "${python_executable} -m venv ${venv_name} &&" + + "${venv_name}/bin/python -m pip install --upgrade pip wheel 'setuptools>=63.0.0' && " + + "touch ${sentinel_file}" } -task installPackage(type: Exec, dependsOn: environmentSetup) { +task installPackage(type: Exec, dependsOn: [environmentSetup, ':metadata-ingestion:codegen']) { + def sentinel_file = "${venv_name}/.build_install_package_sentinel" inputs.file file('setup.py') - outputs.dir("${venv_name}") + outputs.file(sentinel_file) // Workaround for https://github.com/yaml/pyyaml/issues/601. // See https://github.com/yaml/pyyaml/issues/601#issuecomment-1638509577. // and https://github.com/datahub-project/datahub/pull/8435. commandLine 'bash', '-x', '-c', "${pip_install_command} install 'Cython<3.0' 'PyYAML<6' --no-build-isolation && " + - "${pip_install_command} -e ." + "${pip_install_command} -e . ${extra_pip_requirements} &&" + + "touch ${sentinel_file}" } task install(dependsOn: [installPackage]) task installDev(type: Exec, dependsOn: [install]) { + def sentinel_file = "${venv_name}/.build_install_dev_sentinel" inputs.file file('setup.py') - outputs.dir("${venv_name}") - outputs.file("${venv_name}/.build_install_dev_sentinel") + outputs.file("${sentinel_file}") commandLine 'bash', '-x', '-c', - "${pip_install_command} -e .[dev] && touch ${venv_name}/.build_install_dev_sentinel" + "${pip_install_command} -e .[dev] ${extra_pip_requirements} && " + + "touch ${sentinel_file}" } task lint(type: Exec, dependsOn: installDev) { @@ -45,9 +56,13 @@ task lint(type: Exec, dependsOn: installDev) { The find/sed combo below is a temporary work-around for the following mypy issue with airflow 2.2.0: "venv/lib/python3.8/site-packages/airflow/_vendor/connexion/spec.py:169: error: invalid syntax". 
*/ - commandLine 'bash', '-x', '-c', + commandLine 'bash', '-c', "find ${venv_name}/lib -path *airflow/_vendor/connexion/spec.py -exec sed -i.bak -e '169,169s/ # type: List\\[str\\]//g' {} \\; && " + - "source ${venv_name}/bin/activate && black --check --diff src/ tests/ && isort --check --diff src/ tests/ && flake8 --count --statistics src/ tests/ && mypy src/ tests/" + "source ${venv_name}/bin/activate && set -x && " + + "black --check --diff src/ tests/ && " + + "isort --check --diff src/ tests/ && " + + "flake8 --count --statistics src/ tests/ && " + + "mypy --show-traceback --show-error-codes src/ tests/" } task lintFix(type: Exec, dependsOn: installDev) { commandLine 'bash', '-x', '-c', @@ -58,21 +73,13 @@ task lintFix(type: Exec, dependsOn: installDev) { "mypy src/ tests/ " } -task testQuick(type: Exec, dependsOn: installDev) { - // We can't enforce the coverage requirements if we run a subset of the tests. - inputs.files(project.fileTree(dir: "src/", include: "**/*.py")) - inputs.files(project.fileTree(dir: "tests/")) - outputs.dir("${venv_name}") - commandLine 'bash', '-x', '-c', - "source ${venv_name}/bin/activate && pytest -vv --continue-on-collection-errors --junit-xml=junit.quick.xml" -} - task installDevTest(type: Exec, dependsOn: [installDev]) { + def sentinel_file = "${venv_name}/.build_install_dev_test_sentinel" inputs.file file('setup.py') outputs.dir("${venv_name}") - outputs.file("${venv_name}/.build_install_dev_test_sentinel") + outputs.file("${sentinel_file}") commandLine 'bash', '-x', '-c', - "${pip_install_command} -e .[dev,integration-tests] && touch ${venv_name}/.build_install_dev_test_sentinel" + "${pip_install_command} -e .[dev,integration-tests] && touch ${sentinel_file}" } def testFile = hasProperty('testFile') ? testFile : 'unknown' @@ -89,6 +96,16 @@ task testSingle(dependsOn: [installDevTest]) { } } +task testQuick(type: Exec, dependsOn: installDevTest) { + // We can't enforce the coverage requirements if we run a subset of the tests. 
+ inputs.files(project.fileTree(dir: "src/", include: "**/*.py")) + inputs.files(project.fileTree(dir: "tests/")) + outputs.dir("${venv_name}") + commandLine 'bash', '-x', '-c', + "source ${venv_name}/bin/activate && pytest -vv --continue-on-collection-errors --junit-xml=junit.quick.xml" +} + + task testFull(type: Exec, dependsOn: [testQuick, installDevTest]) { commandLine 'bash', '-x', '-c', "source ${venv_name}/bin/activate && pytest -m 'not slow_integration' -vv --continue-on-collection-errors --junit-xml=junit.full.xml" diff --git a/metadata-ingestion-modules/airflow-plugin/pyproject.toml b/metadata-ingestion-modules/airflow-plugin/pyproject.toml index 83b79e3146176..fba81486b9f67 100644 --- a/metadata-ingestion-modules/airflow-plugin/pyproject.toml +++ b/metadata-ingestion-modules/airflow-plugin/pyproject.toml @@ -9,7 +9,6 @@ extend-exclude = ''' ^/tmp ''' include = '\.pyi?$' -target-version = ['py36', 'py37', 'py38'] [tool.isort] indent = ' ' diff --git a/metadata-ingestion-modules/airflow-plugin/setup.cfg b/metadata-ingestion-modules/airflow-plugin/setup.cfg index c9a2ba93e9933..157bcce1c298d 100644 --- a/metadata-ingestion-modules/airflow-plugin/setup.cfg +++ b/metadata-ingestion-modules/airflow-plugin/setup.cfg @@ -69,4 +69,6 @@ exclude_lines = pragma: no cover @abstract if TYPE_CHECKING: -#omit = +omit = + # omit example dags + src/datahub_airflow_plugin/example_dags/* diff --git a/metadata-ingestion-modules/airflow-plugin/setup.py b/metadata-ingestion-modules/airflow-plugin/setup.py index c2571916ca5d0..c5bdc7ea329cd 100644 --- a/metadata-ingestion-modules/airflow-plugin/setup.py +++ b/metadata-ingestion-modules/airflow-plugin/setup.py @@ -13,16 +13,21 @@ def get_long_description(): return pathlib.Path(os.path.join(root, "README.md")).read_text() +rest_common = {"requests", "requests_file"} + base_requirements = { # Compatibility. "dataclasses>=0.6; python_version < '3.7'", - "typing_extensions>=3.10.0.2", + # Typing extension should be >=3.10.0.2 ideally but we can't restrict due to Airflow 2.0.2 dependency conflict + "typing_extensions>=3.7.4.3 ; python_version < '3.8'", + "typing_extensions>=3.10.0.2,<4.6.0 ; python_version >= '3.8'", "mypy_extensions>=0.4.3", # Actual dependencies. "typing-inspect", "pydantic>=1.5.1", "apache-airflow >= 2.0.2", - f"acryl-datahub[airflow] == {package_metadata['__version__']}", + *rest_common, + f"acryl-datahub == {package_metadata['__version__']}", } @@ -47,19 +52,18 @@ def get_long_description(): base_dev_requirements = { *base_requirements, *mypy_stubs, - "black>=21.12b0", + "black==22.12.0", "coverage>=5.1", "flake8>=3.8.3", "flake8-tidy-imports>=4.3.0", "isort>=5.7.0", - "mypy>=0.920", + "mypy>=1.4.0", # pydantic 1.8.2 is incompatible with mypy 0.910. # See https://github.com/samuelcolvin/pydantic/pull/3175#issuecomment-995382910. - "pydantic>=1.9.0", + "pydantic>=1.10", "pytest>=6.2.2", "pytest-asyncio>=0.16.0", "pytest-cov>=2.8.1", - "pytest-docker>=0.10.3,<0.12", "tox", "deepdiff", "requests-mock", @@ -127,5 +131,13 @@ def get_long_description(): "datahub-kafka": [ f"acryl-datahub[datahub-kafka] == {package_metadata['__version__']}" ], + "integration-tests": [ + f"acryl-datahub[datahub-kafka] == {package_metadata['__version__']}", + # Extra requirements for Airflow. + "apache-airflow[snowflake]>=2.0.2", # snowflake is used in example dags + # Because of https://github.com/snowflakedb/snowflake-sqlalchemy/issues/350 we need to restrict SQLAlchemy's max version. 
+ "SQLAlchemy<1.4.42", + "virtualenv", # needed by PythonVirtualenvOperator + ], }, ) diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_airflow_compat.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_airflow_compat.py new file mode 100644 index 0000000000000..67c3348ec987c --- /dev/null +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_airflow_compat.py @@ -0,0 +1,12 @@ +# This module must be imported before any Airflow imports in any of our files. +# The AIRFLOW_PATCHED just helps avoid flake8 errors. + +from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED + +assert MARKUPSAFE_PATCHED + +AIRFLOW_PATCHED = True + +__all__ = [ + "AIRFLOW_PATCHED", +] diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_airflow_shims.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_airflow_shims.py new file mode 100644 index 0000000000000..5ad20e1f72551 --- /dev/null +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_airflow_shims.py @@ -0,0 +1,29 @@ +from airflow.models.baseoperator import BaseOperator + +from datahub_airflow_plugin._airflow_compat import AIRFLOW_PATCHED + +try: + from airflow.models.mappedoperator import MappedOperator + from airflow.models.operator import Operator + from airflow.operators.empty import EmptyOperator +except ModuleNotFoundError: + # Operator isn't a real class, but rather a type alias defined + # as the union of BaseOperator and MappedOperator. + # Since older versions of Airflow don't have MappedOperator, we can just use BaseOperator. + Operator = BaseOperator # type: ignore + MappedOperator = None # type: ignore + from airflow.operators.dummy import DummyOperator as EmptyOperator # type: ignore + +try: + from airflow.sensors.external_task import ExternalTaskSensor +except ImportError: + from airflow.sensors.external_task_sensor import ExternalTaskSensor # type: ignore + +assert AIRFLOW_PATCHED + +__all__ = [ + "Operator", + "MappedOperator", + "EmptyOperator", + "ExternalTaskSensor", +] diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_lineage_core.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_lineage_core.py new file mode 100644 index 0000000000000..d91c039ffa718 --- /dev/null +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/_lineage_core.py @@ -0,0 +1,115 @@ +from datetime import datetime +from typing import TYPE_CHECKING, Dict, List + +import datahub.emitter.mce_builder as builder +from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult +from datahub.configuration.common import ConfigModel +from datahub.utilities.urns.dataset_urn import DatasetUrn + +from datahub_airflow_plugin.client.airflow_generator import AirflowGenerator +from datahub_airflow_plugin.entities import _Entity + +if TYPE_CHECKING: + from airflow import DAG + from airflow.models.dagrun import DagRun + from airflow.models.taskinstance import TaskInstance + + from datahub_airflow_plugin._airflow_shims import Operator + from datahub_airflow_plugin.hooks.datahub import DatahubGenericHook + + +def _entities_to_urn_list(iolets: List[_Entity]) -> List[DatasetUrn]: + return [DatasetUrn.create_from_string(let.urn) for let in iolets] + + +class DatahubBasicLineageConfig(ConfigModel): + enabled: bool = True + + # DataHub hook connection ID. + datahub_conn_id: str + + # Cluster to associate with the pipelines and tasks. 
Defaults to "prod". + cluster: str = builder.DEFAULT_FLOW_CLUSTER + + # If true, the owners field of the DAG will be capture as a DataHub corpuser. + capture_ownership_info: bool = True + + # If true, the tags field of the DAG will be captured as DataHub tags. + capture_tags_info: bool = True + + capture_executions: bool = False + + def make_emitter_hook(self) -> "DatahubGenericHook": + # This is necessary to avoid issues with circular imports. + from datahub_airflow_plugin.hooks.datahub import DatahubGenericHook + + return DatahubGenericHook(self.datahub_conn_id) + + +def send_lineage_to_datahub( + config: DatahubBasicLineageConfig, + operator: "Operator", + inlets: List[_Entity], + outlets: List[_Entity], + context: Dict, +) -> None: + if not config.enabled: + return + + dag: "DAG" = context["dag"] + task: "Operator" = context["task"] + ti: "TaskInstance" = context["task_instance"] + + hook = config.make_emitter_hook() + emitter = hook.make_emitter() + + dataflow = AirflowGenerator.generate_dataflow( + cluster=config.cluster, + dag=dag, + capture_tags=config.capture_tags_info, + capture_owner=config.capture_ownership_info, + ) + dataflow.emit(emitter) + operator.log.info(f"Emitted from Lineage: {dataflow}") + + datajob = AirflowGenerator.generate_datajob( + cluster=config.cluster, + task=task, + dag=dag, + capture_tags=config.capture_tags_info, + capture_owner=config.capture_ownership_info, + ) + datajob.inlets.extend(_entities_to_urn_list(inlets)) + datajob.outlets.extend(_entities_to_urn_list(outlets)) + + datajob.emit(emitter) + operator.log.info(f"Emitted from Lineage: {datajob}") + + if config.capture_executions: + dag_run: "DagRun" = context["dag_run"] + + dpi = AirflowGenerator.run_datajob( + emitter=emitter, + cluster=config.cluster, + ti=ti, + dag=dag, + dag_run=dag_run, + datajob=datajob, + emit_templates=False, + ) + + operator.log.info(f"Emitted from Lineage: {dpi}") + + dpi = AirflowGenerator.complete_datajob( + emitter=emitter, + cluster=config.cluster, + ti=ti, + dag=dag, + dag_run=dag_run, + datajob=datajob, + result=InstanceRunResult.SUCCESS, + end_timestamp_millis=int(datetime.utcnow().timestamp() * 1000), + ) + operator.log.info(f"Emitted from Lineage: {dpi}") + + emitter.flush() diff --git a/metadata-ingestion/src/datahub_provider/example_dags/__init__.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/client/__init__.py similarity index 100% rename from metadata-ingestion/src/datahub_provider/example_dags/__init__.py rename to metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/client/__init__.py diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/client/airflow_generator.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/client/airflow_generator.py new file mode 100644 index 0000000000000..b5e86e14d85d0 --- /dev/null +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/client/airflow_generator.py @@ -0,0 +1,512 @@ +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast + +from airflow.configuration import conf +from datahub.api.entities.datajob import DataFlow, DataJob +from datahub.api.entities.dataprocess.dataprocess_instance import ( + DataProcessInstance, + InstanceRunResult, +) +from datahub.metadata.schema_classes import DataProcessTypeClass +from datahub.utilities.urns.data_flow_urn import DataFlowUrn +from datahub.utilities.urns.data_job_urn import DataJobUrn + +from datahub_airflow_plugin._airflow_compat import AIRFLOW_PATCHED 
+ +assert AIRFLOW_PATCHED + +if TYPE_CHECKING: + from airflow import DAG + from airflow.models import DagRun, TaskInstance + from datahub.emitter.kafka_emitter import DatahubKafkaEmitter + from datahub.emitter.rest_emitter import DatahubRestEmitter + + from datahub_airflow_plugin._airflow_shims import Operator + + +def _task_downstream_task_ids(operator: "Operator") -> Set[str]: + if hasattr(operator, "downstream_task_ids"): + return operator.downstream_task_ids + return operator._downstream_task_id # type: ignore[attr-defined,union-attr] + + +class AirflowGenerator: + @staticmethod + def _get_dependencies( + task: "Operator", dag: "DAG", flow_urn: DataFlowUrn + ) -> List[DataJobUrn]: + from datahub_airflow_plugin._airflow_shims import ExternalTaskSensor + + # resolve URNs for upstream nodes in subdags upstream of the current task. + upstream_subdag_task_urns: List[DataJobUrn] = [] + + for upstream_task_id in task.upstream_task_ids: + upstream_task = dag.task_dict[upstream_task_id] + + # if upstream task is not a subdag, then skip it + upstream_subdag = getattr(upstream_task, "subdag", None) + if upstream_subdag is None: + continue + + # else, link the leaf tasks of the upstream subdag as upstream tasks + for upstream_subdag_task_id in upstream_subdag.task_dict: + upstream_subdag_task = upstream_subdag.task_dict[ + upstream_subdag_task_id + ] + + upstream_subdag_task_urn = DataJobUrn.create_from_ids( + job_id=upstream_subdag_task_id, data_flow_urn=str(flow_urn) + ) + + # if subdag task is a leaf task, then link it as an upstream task + if len(_task_downstream_task_ids(upstream_subdag_task)) == 0: + upstream_subdag_task_urns.append(upstream_subdag_task_urn) + + # resolve URNs for upstream nodes that trigger the subdag containing the current task. + # (if it is in a subdag at all) + upstream_subdag_triggers: List[DataJobUrn] = [] + + # subdags are always named with 'parent.child' style or Airflow won't run them + # add connection from subdag trigger(s) if subdag task has no upstreams + if ( + dag.is_subdag + and dag.parent_dag is not None + and len(task.upstream_task_ids) == 0 + ): + # filter through the parent dag's tasks and find the subdag trigger(s) + subdags = [ + x for x in dag.parent_dag.task_dict.values() if x.subdag is not None + ] + matched_subdags = [ + x for x in subdags if x.subdag and x.subdag.dag_id == dag.dag_id + ] + + # id of the task containing the subdag + subdag_task_id = matched_subdags[0].task_id + + # iterate through the parent dag's tasks and find the ones that trigger the subdag + for upstream_task_id in dag.parent_dag.task_dict: + upstream_task = dag.parent_dag.task_dict[upstream_task_id] + upstream_task_urn = DataJobUrn.create_from_ids( + data_flow_urn=str(flow_urn), job_id=upstream_task_id + ) + + # if the task triggers the subdag, link it to this node in the subdag + if subdag_task_id in _task_downstream_task_ids(upstream_task): + upstream_subdag_triggers.append(upstream_task_urn) + + # If the operator is an ExternalTaskSensor then we set the remote task as upstream. + # It is possible to tie an external sensor to DAG if external_task_id is omitted but currently we can't tie + # jobflow to anothet jobflow. 
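        # For illustration, with hypothetical values: a sensor declared with
        # external_dag_id="upstream_dag" and external_task_id="final_task"
        # yields an upstream URN along the lines of
        #   urn:li:dataJob:(urn:li:dataFlow:(airflow,upstream_dag,prod),final_task)
        # assuming the surrounding flow URN's env resolves to "prod".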
+ external_task_upstreams = [] + if task.task_type == "ExternalTaskSensor": + task = cast(ExternalTaskSensor, task) + if hasattr(task, "external_task_id") and task.external_task_id is not None: + external_task_upstreams = [ + DataJobUrn.create_from_ids( + job_id=task.external_task_id, + data_flow_urn=str( + DataFlowUrn.create_from_ids( + orchestrator=flow_urn.get_orchestrator_name(), + flow_id=task.external_dag_id, + env=flow_urn.get_env(), + ) + ), + ) + ] + # exclude subdag operator tasks since these are not emitted, resulting in empty metadata + upstream_tasks = ( + [ + DataJobUrn.create_from_ids(job_id=task_id, data_flow_urn=str(flow_urn)) + for task_id in task.upstream_task_ids + if getattr(dag.task_dict[task_id], "subdag", None) is None + ] + + upstream_subdag_task_urns + + upstream_subdag_triggers + + external_task_upstreams + ) + return upstream_tasks + + @staticmethod + def generate_dataflow( + cluster: str, + dag: "DAG", + capture_owner: bool = True, + capture_tags: bool = True, + ) -> DataFlow: + """ + Generates a Dataflow object from an Airflow DAG + :param cluster: str - name of the cluster + :param dag: DAG - + :param capture_tags: + :param capture_owner: + :return: DataFlow - Data generated dataflow + """ + id = dag.dag_id + orchestrator = "airflow" + description = f"{dag.description}\n\n{dag.doc_md or ''}" + data_flow = DataFlow( + env=cluster, id=id, orchestrator=orchestrator, description=description + ) + + flow_property_bag: Dict[str, str] = {} + + allowed_flow_keys = [ + "_access_control", + "_concurrency", + "_default_view", + "catchup", + "fileloc", + "is_paused_upon_creation", + "start_date", + "tags", + "timezone", + ] + + for key in allowed_flow_keys: + if hasattr(dag, key): + flow_property_bag[key] = repr(getattr(dag, key)) + + data_flow.properties = flow_property_bag + base_url = conf.get("webserver", "base_url") + data_flow.url = f"{base_url}/tree?dag_id={dag.dag_id}" + + if capture_owner and dag.owner: + data_flow.owners.add(dag.owner) + + if capture_tags and dag.tags: + data_flow.tags.update(dag.tags) + + return data_flow + + @staticmethod + def _get_description(task: "Operator") -> Optional[str]: + from airflow.models.baseoperator import BaseOperator + + if not isinstance(task, BaseOperator): + # TODO: Get docs for mapped operators. 
+ return None + + if hasattr(task, "doc") and task.doc: + return task.doc + elif hasattr(task, "doc_md") and task.doc_md: + return task.doc_md + elif hasattr(task, "doc_json") and task.doc_json: + return task.doc_json + elif hasattr(task, "doc_yaml") and task.doc_yaml: + return task.doc_yaml + elif hasattr(task, "doc_rst") and task.doc_yaml: + return task.doc_yaml + return None + + @staticmethod + def generate_datajob( + cluster: str, + task: "Operator", + dag: "DAG", + set_dependencies: bool = True, + capture_owner: bool = True, + capture_tags: bool = True, + ) -> DataJob: + """ + + :param cluster: str + :param task: TaskIntance + :param dag: DAG + :param set_dependencies: bool - whether to extract dependencies from airflow task + :param capture_owner: bool - whether to extract owner from airflow task + :param capture_tags: bool - whether to set tags automatically from airflow task + :return: DataJob - returns the generated DataJob object + """ + dataflow_urn = DataFlowUrn.create_from_ids( + orchestrator="airflow", env=cluster, flow_id=dag.dag_id + ) + datajob = DataJob(id=task.task_id, flow_urn=dataflow_urn) + + # TODO add support for MappedOperator + datajob.description = AirflowGenerator._get_description(task) + + job_property_bag: Dict[str, str] = {} + + allowed_task_keys = [ + "_downstream_task_ids", + "_inlets", + "_outlets", + "_task_type", + "_task_module", + "depends_on_past", + "email", + "label", + "execution_timeout", + "sla", + "sql", + "task_id", + "trigger_rule", + "wait_for_downstream", + # In Airflow 2.3, _downstream_task_ids was renamed to downstream_task_ids + "downstream_task_ids", + # In Airflow 2.4, _inlets and _outlets were removed in favor of non-private versions. + "inlets", + "outlets", + ] + + for key in allowed_task_keys: + if hasattr(task, key): + job_property_bag[key] = repr(getattr(task, key)) + + datajob.properties = job_property_bag + base_url = conf.get("webserver", "base_url") + datajob.url = f"{base_url}/taskinstance/list/?flt1_dag_id_equals={datajob.flow_urn.get_flow_id()}&_flt_3_task_id={task.task_id}" + + if capture_owner and dag.owner: + datajob.owners.add(dag.owner) + + if capture_tags and dag.tags: + datajob.tags.update(dag.tags) + + if set_dependencies: + datajob.upstream_urns.extend( + AirflowGenerator._get_dependencies( + task=task, dag=dag, flow_urn=datajob.flow_urn + ) + ) + + return datajob + + @staticmethod + def create_datajob_instance( + cluster: str, + task: "Operator", + dag: "DAG", + data_job: Optional[DataJob] = None, + ) -> DataProcessInstance: + if data_job is None: + data_job = AirflowGenerator.generate_datajob(cluster, task=task, dag=dag) + dpi = DataProcessInstance.from_datajob( + datajob=data_job, id=task.task_id, clone_inlets=True, clone_outlets=True + ) + return dpi + + @staticmethod + def run_dataflow( + emitter: Union["DatahubRestEmitter", "DatahubKafkaEmitter"], + cluster: str, + dag_run: "DagRun", + start_timestamp_millis: Optional[int] = None, + dataflow: Optional[DataFlow] = None, + ) -> None: + if dataflow is None: + assert dag_run.dag + dataflow = AirflowGenerator.generate_dataflow(cluster, dag_run.dag) + + if start_timestamp_millis is None: + assert dag_run.execution_date + start_timestamp_millis = int(dag_run.execution_date.timestamp() * 1000) + + assert dag_run.run_id + dpi = DataProcessInstance.from_dataflow(dataflow=dataflow, id=dag_run.run_id) + + # This property only exists in Airflow2 + if hasattr(dag_run, "run_type"): + from airflow.utils.types import DagRunType + + if dag_run.run_type == DagRunType.SCHEDULED: 
+ dpi.type = DataProcessTypeClass.BATCH_SCHEDULED + elif dag_run.run_type == DagRunType.MANUAL: + dpi.type = DataProcessTypeClass.BATCH_AD_HOC + else: + if dag_run.run_id.startswith("scheduled__"): + dpi.type = DataProcessTypeClass.BATCH_SCHEDULED + else: + dpi.type = DataProcessTypeClass.BATCH_AD_HOC + + property_bag: Dict[str, str] = {} + property_bag["run_id"] = str(dag_run.run_id) + property_bag["execution_date"] = str(dag_run.execution_date) + property_bag["end_date"] = str(dag_run.end_date) + property_bag["start_date"] = str(dag_run.start_date) + property_bag["creating_job_id"] = str(dag_run.creating_job_id) + # These properties only exists in Airflow>=2.2.0 + if hasattr(dag_run, "data_interval_start") and hasattr( + dag_run, "data_interval_end" + ): + property_bag["data_interval_start"] = str(dag_run.data_interval_start) + property_bag["data_interval_end"] = str(dag_run.data_interval_end) + property_bag["external_trigger"] = str(dag_run.external_trigger) + dpi.properties.update(property_bag) + + dpi.emit_process_start( + emitter=emitter, start_timestamp_millis=start_timestamp_millis + ) + + @staticmethod + def complete_dataflow( + emitter: Union["DatahubRestEmitter", "DatahubKafkaEmitter"], + cluster: str, + dag_run: "DagRun", + end_timestamp_millis: Optional[int] = None, + dataflow: Optional[DataFlow] = None, + ) -> None: + """ + + :param emitter: DatahubRestEmitter - the datahub rest emitter to emit the generated mcps + :param cluster: str - name of the cluster + :param dag_run: DagRun + :param end_timestamp_millis: Optional[int] - the completion time in milliseconds if not set the current time will be used. + :param dataflow: Optional[Dataflow] + """ + if dataflow is None: + assert dag_run.dag + dataflow = AirflowGenerator.generate_dataflow(cluster, dag_run.dag) + + assert dag_run.run_id + dpi = DataProcessInstance.from_dataflow(dataflow=dataflow, id=dag_run.run_id) + if end_timestamp_millis is None: + if dag_run.end_date is None: + raise Exception( + f"Dag {dag_run.dag_id}_{dag_run.run_id} is still running and unable to get end_date..." 
+ ) + end_timestamp_millis = int(dag_run.end_date.timestamp() * 1000) + + # We should use DagRunState but it is not available in Airflow 1 + if dag_run.state == "success": + result = InstanceRunResult.SUCCESS + elif dag_run.state == "failed": + result = InstanceRunResult.FAILURE + else: + raise Exception( + f"Result should be either success or failure and it was {dag_run.state}" + ) + + dpi.emit_process_end( + emitter=emitter, + end_timestamp_millis=end_timestamp_millis, + result=result, + result_type="airflow", + ) + + @staticmethod + def run_datajob( + emitter: Union["DatahubRestEmitter", "DatahubKafkaEmitter"], + cluster: str, + ti: "TaskInstance", + dag: "DAG", + dag_run: "DagRun", + start_timestamp_millis: Optional[int] = None, + datajob: Optional[DataJob] = None, + attempt: Optional[int] = None, + emit_templates: bool = True, + ) -> DataProcessInstance: + if datajob is None: + datajob = AirflowGenerator.generate_datajob(cluster, ti.task, dag) + + assert dag_run.run_id + dpi = DataProcessInstance.from_datajob( + datajob=datajob, + id=f"{dag.dag_id}_{ti.task_id}_{dag_run.run_id}", + clone_inlets=True, + clone_outlets=True, + ) + job_property_bag: Dict[str, str] = {} + job_property_bag["run_id"] = str(dag_run.run_id) + job_property_bag["duration"] = str(ti.duration) + job_property_bag["start_date"] = str(ti.start_date) + job_property_bag["end_date"] = str(ti.end_date) + job_property_bag["execution_date"] = str(ti.execution_date) + job_property_bag["try_number"] = str(ti.try_number - 1) + job_property_bag["hostname"] = str(ti.hostname) + job_property_bag["max_tries"] = str(ti.max_tries) + # Not compatible with Airflow 1 + if hasattr(ti, "external_executor_id"): + job_property_bag["external_executor_id"] = str(ti.external_executor_id) + job_property_bag["pid"] = str(ti.pid) + job_property_bag["state"] = str(ti.state) + job_property_bag["operator"] = str(ti.operator) + job_property_bag["priority_weight"] = str(ti.priority_weight) + job_property_bag["unixname"] = str(ti.unixname) + job_property_bag["log_url"] = ti.log_url + dpi.properties.update(job_property_bag) + dpi.url = ti.log_url + + # This property only exists in Airflow2 + if hasattr(ti, "dag_run") and hasattr(ti.dag_run, "run_type"): + from airflow.utils.types import DagRunType + + if ti.dag_run.run_type == DagRunType.SCHEDULED: + dpi.type = DataProcessTypeClass.BATCH_SCHEDULED + elif ti.dag_run.run_type == DagRunType.MANUAL: + dpi.type = DataProcessTypeClass.BATCH_AD_HOC + else: + if dag_run.run_id.startswith("scheduled__"): + dpi.type = DataProcessTypeClass.BATCH_SCHEDULED + else: + dpi.type = DataProcessTypeClass.BATCH_AD_HOC + + if start_timestamp_millis is None: + assert ti.start_date + start_timestamp_millis = int(ti.start_date.timestamp() * 1000) + + if attempt is None: + attempt = ti.try_number + + dpi.emit_process_start( + emitter=emitter, + start_timestamp_millis=start_timestamp_millis, + attempt=attempt, + emit_template=emit_templates, + ) + return dpi + + @staticmethod + def complete_datajob( + emitter: Union["DatahubRestEmitter", "DatahubKafkaEmitter"], + cluster: str, + ti: "TaskInstance", + dag: "DAG", + dag_run: "DagRun", + end_timestamp_millis: Optional[int] = None, + result: Optional[InstanceRunResult] = None, + datajob: Optional[DataJob] = None, + ) -> DataProcessInstance: + """ + + :param emitter: DatahubRestEmitter + :param cluster: str + :param ti: TaskInstance + :param dag: DAG + :param dag_run: DagRun + :param end_timestamp_millis: Optional[int] + :param result: Optional[str] One of the result from 
datahub.metadata.schema_class.RunResultTypeClass + :param datajob: Optional[DataJob] + :return: DataProcessInstance + """ + if datajob is None: + datajob = AirflowGenerator.generate_datajob(cluster, ti.task, dag) + + if end_timestamp_millis is None: + assert ti.end_date + end_timestamp_millis = int(ti.end_date.timestamp() * 1000) + + if result is None: + # We should use TaskInstanceState but it is not available in Airflow 1 + if ti.state == "success": + result = InstanceRunResult.SUCCESS + elif ti.state == "failed": + result = InstanceRunResult.FAILURE + else: + raise Exception( + f"Result should be either success or failure and it was {ti.state}" + ) + + dpi = DataProcessInstance.from_datajob( + datajob=datajob, + id=f"{dag.dag_id}_{ti.task_id}_{dag_run.run_id}", + clone_inlets=True, + clone_outlets=True, + ) + dpi.emit_process_end( + emitter=emitter, + end_timestamp_millis=end_timestamp_millis, + result=result, + result_type="airflow", + ) + return dpi diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py index 226a7382f7595..d1cec9e5c1b54 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_plugin.py @@ -1,4 +1,367 @@ -# This package serves as a shim, but the actual implementation lives in datahub_provider -# from the acryl-datahub package. We leave this shim here to avoid breaking existing -# Airflow installs. -from datahub_provider._plugin import DatahubPlugin # noqa: F401 +import contextlib +import logging +import traceback +from typing import Any, Callable, Iterable, List, Optional, Union + +from airflow.configuration import conf +from airflow.lineage import PIPELINE_OUTLETS +from airflow.models.baseoperator import BaseOperator +from airflow.plugins_manager import AirflowPlugin +from airflow.utils.module_loading import import_string +from cattr import structure +from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult + +from datahub_airflow_plugin._airflow_compat import AIRFLOW_PATCHED +from datahub_airflow_plugin._airflow_shims import MappedOperator, Operator +from datahub_airflow_plugin.client.airflow_generator import AirflowGenerator +from datahub_airflow_plugin.hooks.datahub import DatahubGenericHook +from datahub_airflow_plugin.lineage.datahub import DatahubLineageConfig + +assert AIRFLOW_PATCHED +logger = logging.getLogger(__name__) + +TASK_ON_FAILURE_CALLBACK = "on_failure_callback" +TASK_ON_SUCCESS_CALLBACK = "on_success_callback" + + +def get_lineage_config() -> DatahubLineageConfig: + """Load the lineage config from airflow.cfg.""" + + enabled = conf.get("datahub", "enabled", fallback=True) + datahub_conn_id = conf.get("datahub", "conn_id", fallback="datahub_rest_default") + cluster = conf.get("datahub", "cluster", fallback="prod") + graceful_exceptions = conf.get("datahub", "graceful_exceptions", fallback=True) + capture_tags_info = conf.get("datahub", "capture_tags_info", fallback=True) + capture_ownership_info = conf.get( + "datahub", "capture_ownership_info", fallback=True + ) + capture_executions = conf.get("datahub", "capture_executions", fallback=True) + return DatahubLineageConfig( + enabled=enabled, + datahub_conn_id=datahub_conn_id, + cluster=cluster, + graceful_exceptions=graceful_exceptions, + capture_ownership_info=capture_ownership_info, + 
capture_tags_info=capture_tags_info, + capture_executions=capture_executions, + ) + + +def _task_inlets(operator: "Operator") -> List: + # From Airflow 2.4 _inlets is dropped and inlets used consistently. Earlier it was not the case, so we have to stick there to _inlets + if hasattr(operator, "_inlets"): + return operator._inlets # type: ignore[attr-defined, union-attr] + return operator.inlets + + +def _task_outlets(operator: "Operator") -> List: + # From Airflow 2.4 _outlets is dropped and inlets used consistently. Earlier it was not the case, so we have to stick there to _outlets + # We have to use _outlets because outlets is empty in Airflow < 2.4.0 + if hasattr(operator, "_outlets"): + return operator._outlets # type: ignore[attr-defined, union-attr] + return operator.outlets + + +def get_inlets_from_task(task: BaseOperator, context: Any) -> Iterable[Any]: + # TODO: Fix for https://github.com/apache/airflow/commit/1b1f3fabc5909a447a6277cafef3a0d4ef1f01ae + # in Airflow 2.4. + # TODO: ignore/handle airflow's dataset type in our lineage + + inlets: List[Any] = [] + task_inlets = _task_inlets(task) + # From Airflow 2.3 this should be AbstractOperator but due to compatibility reason lets use BaseOperator + if isinstance(task_inlets, (str, BaseOperator)): + inlets = [ + task_inlets, + ] + + if task_inlets and isinstance(task_inlets, list): + inlets = [] + task_ids = ( + {o for o in task_inlets if isinstance(o, str)} + .union(op.task_id for op in task_inlets if isinstance(op, BaseOperator)) + .intersection(task.get_flat_relative_ids(upstream=True)) + ) + + from airflow.lineage import AUTO + + # pick up unique direct upstream task_ids if AUTO is specified + if AUTO.upper() in task_inlets or AUTO.lower() in task_inlets: + print("Picking up unique direct upstream task_ids as AUTO is specified") + task_ids = task_ids.union( + task_ids.symmetric_difference(task.upstream_task_ids) + ) + + inlets = task.xcom_pull( + context, task_ids=list(task_ids), dag_id=task.dag_id, key=PIPELINE_OUTLETS + ) + + # re-instantiate the obtained inlets + inlets = [ + structure(item["data"], import_string(item["type_name"])) + # _get_instance(structure(item, Metadata)) + for sublist in inlets + if sublist + for item in sublist + ] + + for inlet in task_inlets: + if not isinstance(inlet, str): + inlets.append(inlet) + + return inlets + + +def _make_emit_callback( + logger: logging.Logger, +) -> Callable[[Optional[Exception], str], None]: + def emit_callback(err: Optional[Exception], msg: str) -> None: + if err: + logger.error(f"Error sending metadata to datahub: {msg}", exc_info=err) + + return emit_callback + + +def datahub_task_status_callback(context, status): + ti = context["ti"] + task: "BaseOperator" = ti.task + dag = context["dag"] + + # This code is from the original airflow lineage code -> + # https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py + inlets = get_inlets_from_task(task, context) + + emitter = ( + DatahubGenericHook(context["_datahub_config"].datahub_conn_id) + .get_underlying_hook() + .make_emitter() + ) + + dataflow = AirflowGenerator.generate_dataflow( + cluster=context["_datahub_config"].cluster, + dag=dag, + capture_tags=context["_datahub_config"].capture_tags_info, + capture_owner=context["_datahub_config"].capture_ownership_info, + ) + task.log.info(f"Emitting Datahub Dataflow: {dataflow}") + dataflow.emit(emitter, callback=_make_emit_callback(task.log)) + + datajob = AirflowGenerator.generate_datajob( + cluster=context["_datahub_config"].cluster, + task=task, + dag=dag, 
+ capture_tags=context["_datahub_config"].capture_tags_info, + capture_owner=context["_datahub_config"].capture_ownership_info, + ) + + for inlet in inlets: + datajob.inlets.append(inlet.urn) + + task_outlets = _task_outlets(task) + for outlet in task_outlets: + datajob.outlets.append(outlet.urn) + + task.log.info(f"Emitting Datahub Datajob: {datajob}") + datajob.emit(emitter, callback=_make_emit_callback(task.log)) + + if context["_datahub_config"].capture_executions: + dpi = AirflowGenerator.run_datajob( + emitter=emitter, + cluster=context["_datahub_config"].cluster, + ti=context["ti"], + dag=dag, + dag_run=context["dag_run"], + datajob=datajob, + start_timestamp_millis=int(ti.start_date.timestamp() * 1000), + ) + + task.log.info(f"Emitted Start Datahub Dataprocess Instance: {dpi}") + + dpi = AirflowGenerator.complete_datajob( + emitter=emitter, + cluster=context["_datahub_config"].cluster, + ti=context["ti"], + dag_run=context["dag_run"], + result=status, + dag=dag, + datajob=datajob, + end_timestamp_millis=int(ti.end_date.timestamp() * 1000), + ) + task.log.info(f"Emitted Completed Data Process Instance: {dpi}") + + emitter.flush() + + +def datahub_pre_execution(context): + ti = context["ti"] + task: "BaseOperator" = ti.task + dag = context["dag"] + + task.log.info("Running Datahub pre_execute method") + + emitter = ( + DatahubGenericHook(context["_datahub_config"].datahub_conn_id) + .get_underlying_hook() + .make_emitter() + ) + + # This code is from the original airflow lineage code -> + # https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py + inlets = get_inlets_from_task(task, context) + + datajob = AirflowGenerator.generate_datajob( + cluster=context["_datahub_config"].cluster, + task=context["ti"].task, + dag=dag, + capture_tags=context["_datahub_config"].capture_tags_info, + capture_owner=context["_datahub_config"].capture_ownership_info, + ) + + for inlet in inlets: + datajob.inlets.append(inlet.urn) + + task_outlets = _task_outlets(task) + + for outlet in task_outlets: + datajob.outlets.append(outlet.urn) + + task.log.info(f"Emitting Datahub dataJob {datajob}") + datajob.emit(emitter, callback=_make_emit_callback(task.log)) + + if context["_datahub_config"].capture_executions: + dpi = AirflowGenerator.run_datajob( + emitter=emitter, + cluster=context["_datahub_config"].cluster, + ti=context["ti"], + dag=dag, + dag_run=context["dag_run"], + datajob=datajob, + start_timestamp_millis=int(ti.start_date.timestamp() * 1000), + ) + + task.log.info(f"Emitting Datahub Dataprocess Instance: {dpi}") + + emitter.flush() + + +def _wrap_pre_execution(pre_execution): + def custom_pre_execution(context): + config = get_lineage_config() + if config.enabled: + context["_datahub_config"] = config + datahub_pre_execution(context) + + # Call original policy + if pre_execution: + pre_execution(context) + + return custom_pre_execution + + +def _wrap_on_failure_callback(on_failure_callback): + def custom_on_failure_callback(context): + config = get_lineage_config() + if config.enabled: + context["_datahub_config"] = config + try: + datahub_task_status_callback(context, status=InstanceRunResult.FAILURE) + except Exception as e: + if not config.graceful_exceptions: + raise e + else: + print(f"Exception: {traceback.format_exc()}") + + # Call original policy + if on_failure_callback: + on_failure_callback(context) + + return custom_on_failure_callback + + +def _wrap_on_success_callback(on_success_callback): + def custom_on_success_callback(context): + config = get_lineage_config() + 
if config.enabled: + context["_datahub_config"] = config + try: + datahub_task_status_callback(context, status=InstanceRunResult.SUCCESS) + except Exception as e: + if not config.graceful_exceptions: + raise e + else: + print(f"Exception: {traceback.format_exc()}") + + # Call original policy + if on_success_callback: + on_success_callback(context) + + return custom_on_success_callback + + +def task_policy(task: Union[BaseOperator, MappedOperator]) -> None: + task.log.debug(f"Setting task policy for Dag: {task.dag_id} Task: {task.task_id}") + # task.add_inlets(["auto"]) + # task.pre_execute = _wrap_pre_execution(task.pre_execute) + + # MappedOperator's callbacks don't have setters until Airflow 2.X.X + # https://github.com/apache/airflow/issues/24547 + # We can bypass this by going through partial_kwargs for now + if MappedOperator and isinstance(task, MappedOperator): # type: ignore + on_failure_callback_prop: property = getattr( + MappedOperator, TASK_ON_FAILURE_CALLBACK + ) + on_success_callback_prop: property = getattr( + MappedOperator, TASK_ON_SUCCESS_CALLBACK + ) + if not on_failure_callback_prop.fset or not on_success_callback_prop.fset: + task.log.debug( + "Using MappedOperator's partial_kwargs instead of callback properties" + ) + task.partial_kwargs[TASK_ON_FAILURE_CALLBACK] = _wrap_on_failure_callback( + task.on_failure_callback + ) + task.partial_kwargs[TASK_ON_SUCCESS_CALLBACK] = _wrap_on_success_callback( + task.on_success_callback + ) + return + + task.on_failure_callback = _wrap_on_failure_callback(task.on_failure_callback) # type: ignore + task.on_success_callback = _wrap_on_success_callback(task.on_success_callback) # type: ignore + # task.pre_execute = _wrap_pre_execution(task.pre_execute) + + +def _wrap_task_policy(policy): + if policy and hasattr(policy, "_task_policy_patched_by"): + return policy + + def custom_task_policy(task): + policy(task) + task_policy(task) + + # Add a flag to the policy to indicate that we've patched it. 
+ custom_task_policy._task_policy_patched_by = "datahub_plugin" # type: ignore[attr-defined] + return custom_task_policy + + +def _patch_policy(settings): + if hasattr(settings, "task_policy"): + datahub_task_policy = _wrap_task_policy(settings.task_policy) + settings.task_policy = datahub_task_policy + + +def _patch_datahub_policy(): + with contextlib.suppress(ImportError): + import airflow_local_settings + + _patch_policy(airflow_local_settings) + + from airflow.models.dagbag import settings + + _patch_policy(settings) + + +_patch_datahub_policy() + + +class DatahubPlugin(AirflowPlugin): + name = "datahub_plugin" diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/entities.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/entities.py new file mode 100644 index 0000000000000..69f667cad3241 --- /dev/null +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/entities.py @@ -0,0 +1,47 @@ +from abc import abstractmethod +from typing import Optional + +import attr +import datahub.emitter.mce_builder as builder +from datahub.utilities.urns.urn import guess_entity_type + + +class _Entity: + @property + @abstractmethod + def urn(self) -> str: + pass + + +@attr.s(auto_attribs=True, str=True) +class Dataset(_Entity): + platform: str + name: str + env: str = builder.DEFAULT_ENV + platform_instance: Optional[str] = None + + @property + def urn(self): + return builder.make_dataset_urn_with_platform_instance( + platform=self.platform, + name=self.name, + platform_instance=self.platform_instance, + env=self.env, + ) + + +@attr.s(str=True) +class Urn(_Entity): + _urn: str = attr.ib() + + @_urn.validator + def _validate_urn(self, attribute, value): + if not value.startswith("urn:"): + raise ValueError("invalid urn provided: urns must start with 'urn:'") + if guess_entity_type(value) != "dataset": + # This is because DataJobs only support Dataset lineage. 
+ raise ValueError("Airflow lineage currently only supports datasets") + + @property + def urn(self): + return self._urn diff --git a/metadata-ingestion/src/datahub_provider/example_dags/.airflowignore b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/.airflowignore similarity index 100% rename from metadata-ingestion/src/datahub_provider/example_dags/.airflowignore rename to metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/.airflowignore diff --git a/.github/workflows/docker-ingestion-base.yml b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/__init__.py similarity index 100% rename from .github/workflows/docker-ingestion-base.yml rename to metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/__init__.py diff --git a/metadata-ingestion/src/datahub_provider/example_dags/generic_recipe_sample_dag.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/generic_recipe_sample_dag.py similarity index 98% rename from metadata-ingestion/src/datahub_provider/example_dags/generic_recipe_sample_dag.py rename to metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/generic_recipe_sample_dag.py index d0e4aa944e840..ff8dba457066f 100644 --- a/metadata-ingestion/src/datahub_provider/example_dags/generic_recipe_sample_dag.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/generic_recipe_sample_dag.py @@ -9,7 +9,6 @@ from airflow import DAG from airflow.operators.python import PythonOperator from airflow.utils.dates import days_ago - from datahub.configuration.config_loader import load_config_file from datahub.ingestion.run.pipeline import Pipeline @@ -41,6 +40,7 @@ def datahub_recipe(): schedule_interval=timedelta(days=1), start_date=days_ago(2), catchup=False, + default_view="tree", ) as dag: ingest_task = PythonOperator( task_id="ingest_using_recipe", diff --git a/metadata-ingestion/src/datahub_provider/example_dags/lineage_backend_demo.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_backend_demo.py similarity index 94% rename from metadata-ingestion/src/datahub_provider/example_dags/lineage_backend_demo.py rename to metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_backend_demo.py index 95b594e4052a5..3caea093b932d 100644 --- a/metadata-ingestion/src/datahub_provider/example_dags/lineage_backend_demo.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_backend_demo.py @@ -9,7 +9,7 @@ from airflow.operators.bash import BashOperator from airflow.utils.dates import days_ago -from datahub_provider.entities import Dataset, Urn +from datahub_airflow_plugin.entities import Dataset, Urn default_args = { "owner": "airflow", @@ -28,6 +28,7 @@ start_date=days_ago(2), tags=["example_tag"], catchup=False, + default_view="tree", ) as dag: task1 = BashOperator( task_id="run_data_task", diff --git a/metadata-ingestion/src/datahub_provider/example_dags/lineage_backend_taskflow_demo.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_backend_taskflow_demo.py similarity index 94% rename from metadata-ingestion/src/datahub_provider/example_dags/lineage_backend_taskflow_demo.py rename to metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_backend_taskflow_demo.py index 
1fe321eb5c80a..ceb0f452b540a 100644 --- a/metadata-ingestion/src/datahub_provider/example_dags/lineage_backend_taskflow_demo.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_backend_taskflow_demo.py @@ -8,7 +8,7 @@ from airflow.decorators import dag, task from airflow.utils.dates import days_ago -from datahub_provider.entities import Dataset, Urn +from datahub_airflow_plugin.entities import Dataset, Urn default_args = { "owner": "airflow", @@ -26,6 +26,7 @@ start_date=days_ago(2), tags=["example_tag"], catchup=False, + default_view="tree", ) def datahub_lineage_backend_taskflow_demo(): @task( diff --git a/metadata-ingestion/src/datahub_provider/example_dags/lineage_emission_dag.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_emission_dag.py similarity index 96% rename from metadata-ingestion/src/datahub_provider/example_dags/lineage_emission_dag.py rename to metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_emission_dag.py index 153464246cef7..f40295c6bb883 100644 --- a/metadata-ingestion/src/datahub_provider/example_dags/lineage_emission_dag.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/lineage_emission_dag.py @@ -5,12 +5,12 @@ from datetime import timedelta +import datahub.emitter.mce_builder as builder from airflow import DAG from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator from airflow.utils.dates import days_ago -import datahub.emitter.mce_builder as builder -from datahub_provider.operators.datahub import DatahubEmitterOperator +from datahub_airflow_plugin.operators.datahub import DatahubEmitterOperator default_args = { "owner": "airflow", @@ -31,6 +31,7 @@ schedule_interval=timedelta(days=1), start_date=days_ago(2), catchup=False, + default_view="tree", ) as dag: # This example shows a SnowflakeOperator followed by a lineage emission. However, the # same DatahubEmitterOperator can be used to emit lineage in any context. 
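For orientation between the renamed example DAGs: the hunks above only switch imports from `datahub_provider` to `datahub_airflow_plugin` and add `default_view="tree"`. The sketch below is not part of this patch; it shows how a DAG might declare lineage with the relocated `datahub_airflow_plugin.entities` helpers. The DAG id, bash command, and dataset names are illustrative assumptions.

```python
# Minimal sketch, assuming Airflow 2.x and the acryl-datahub-airflow-plugin package are installed.
from datetime import timedelta

from airflow import DAG
from airflow.operators.bash import BashOperator
from airflow.utils.dates import days_ago

from datahub_airflow_plugin.entities import Dataset, Urn

with DAG(
    "datahub_lineage_sketch",  # hypothetical DAG id
    start_date=days_ago(2),
    schedule_interval=timedelta(days=1),
    catchup=False,
    default_view="tree",  # mirrors the default_view added to the example DAGs above
) as dag:
    transform = BashOperator(
        task_id="transform",
        bash_command="echo 'transform data'",  # placeholder command
        # inlets/outlets accept the Dataset and Urn helpers defined in entities.py;
        # Urn values must be dataset URNs.
        inlets=[Dataset("snowflake", "mydb.schema.table_a")],
        outlets=[Urn("urn:li:dataset:(urn:li:dataPlatform:snowflake,mydb.schema.table_b,PROD)")],
    )
```

With the plugin or lineage backend introduced by this patch enabled, these inlets and outlets are read off the task at run time and attached to the emitted DataJob, so basic lineage needs no additional operator.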
diff --git a/metadata-ingestion/src/datahub_provider/example_dags/mysql_sample_dag.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/mysql_sample_dag.py similarity index 98% rename from metadata-ingestion/src/datahub_provider/example_dags/mysql_sample_dag.py rename to metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/mysql_sample_dag.py index 2c833e1425634..77b29711d7688 100644 --- a/metadata-ingestion/src/datahub_provider/example_dags/mysql_sample_dag.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/mysql_sample_dag.py @@ -47,6 +47,7 @@ def ingest_from_mysql(): start_date=datetime(2022, 1, 1), schedule_interval=timedelta(days=1), catchup=False, + default_view="tree", ) as dag: # While it is also possible to use the PythonOperator, we recommend using # the PythonVirtualenvOperator to ensure that there are no dependency diff --git a/metadata-ingestion/src/datahub_provider/example_dags/snowflake_sample_dag.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/snowflake_sample_dag.py similarity index 99% rename from metadata-ingestion/src/datahub_provider/example_dags/snowflake_sample_dag.py rename to metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/snowflake_sample_dag.py index c107bb479262c..30e63b68e459f 100644 --- a/metadata-ingestion/src/datahub_provider/example_dags/snowflake_sample_dag.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/snowflake_sample_dag.py @@ -57,6 +57,7 @@ def ingest_from_snowflake(snowflake_credentials, datahub_gms_server): start_date=datetime(2022, 1, 1), schedule_interval=timedelta(days=1), catchup=False, + default_view="tree", ) as dag: # This example pulls credentials from Airflow's connection store. # For this to work, you must have previously configured these connections in Airflow. diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/hooks/__init__.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/hooks/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/hooks/datahub.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/hooks/datahub.py new file mode 100644 index 0000000000000..aed858c6c4df0 --- /dev/null +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/hooks/datahub.py @@ -0,0 +1,214 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from datahub.metadata.com.linkedin.pegasus2avro.mxe import ( + MetadataChangeEvent, + MetadataChangeProposal, +) + +if TYPE_CHECKING: + from airflow.models.connection import Connection + from datahub.emitter.kafka_emitter import DatahubKafkaEmitter + from datahub.emitter.rest_emitter import DatahubRestEmitter + from datahub.ingestion.sink.datahub_kafka import KafkaSinkConfig + + +class DatahubRestHook(BaseHook): + """ + Creates a DataHub Rest API connection used to send metadata to DataHub. + Takes the endpoint for your DataHub Rest API in the Server Endpoint(host) field. + + URI example: :: + + AIRFLOW_CONN_DATAHUB_REST_DEFAULT='datahub-rest://rest-endpoint' + + :param datahub_rest_conn_id: Reference to the DataHub Rest connection. 
+ :type datahub_rest_conn_id: str + """ + + conn_name_attr = "datahub_rest_conn_id" + default_conn_name = "datahub_rest_default" + conn_type = "datahub_rest" + hook_name = "DataHub REST Server" + + def __init__(self, datahub_rest_conn_id: str = default_conn_name) -> None: + super().__init__() + self.datahub_rest_conn_id = datahub_rest_conn_id + + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + return {} + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behavior""" + return { + "hidden_fields": ["port", "schema", "login"], + "relabeling": { + "host": "Server Endpoint", + }, + } + + def _get_config(self) -> Tuple[str, Optional[str], Optional[int]]: + conn: "Connection" = self.get_connection(self.datahub_rest_conn_id) + + host = conn.host + if not host: + raise AirflowException("host parameter is required") + if conn.port: + if ":" in host: + raise AirflowException( + "host parameter should not contain a port number if the port is specified separately" + ) + host = f"{host}:{conn.port}" + password = conn.password + timeout_sec = conn.extra_dejson.get("timeout_sec") + return (host, password, timeout_sec) + + def make_emitter(self) -> "DatahubRestEmitter": + import datahub.emitter.rest_emitter + + return datahub.emitter.rest_emitter.DatahubRestEmitter(*self._get_config()) + + def emit_mces(self, mces: List[MetadataChangeEvent]) -> None: + emitter = self.make_emitter() + + for mce in mces: + emitter.emit_mce(mce) + + def emit_mcps(self, mcps: List[MetadataChangeProposal]) -> None: + emitter = self.make_emitter() + + for mce in mcps: + emitter.emit_mcp(mce) + + +class DatahubKafkaHook(BaseHook): + """ + Creates a DataHub Kafka connection used to send metadata to DataHub. + Takes your kafka broker in the Kafka Broker(host) field. + + URI example: :: + + AIRFLOW_CONN_DATAHUB_KAFKA_DEFAULT='datahub-kafka://kafka-broker' + + :param datahub_kafka_conn_id: Reference to the DataHub Kafka connection. 
+ :type datahub_kafka_conn_id: str + """ + + conn_name_attr = "datahub_kafka_conn_id" + default_conn_name = "datahub_kafka_default" + conn_type = "datahub_kafka" + hook_name = "DataHub Kafka Sink" + + def __init__(self, datahub_kafka_conn_id: str = default_conn_name) -> None: + super().__init__() + self.datahub_kafka_conn_id = datahub_kafka_conn_id + + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + return {} + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behavior""" + return { + "hidden_fields": ["port", "schema", "login", "password"], + "relabeling": { + "host": "Kafka Broker", + }, + } + + def _get_config(self) -> "KafkaSinkConfig": + import datahub.ingestion.sink.datahub_kafka + + conn = self.get_connection(self.datahub_kafka_conn_id) + obj = conn.extra_dejson + obj.setdefault("connection", {}) + if conn.host is not None: + if "bootstrap" in obj["connection"]: + raise AirflowException( + "Kafka broker specified twice (present in host and extra)" + ) + obj["connection"]["bootstrap"] = ":".join( + map(str, filter(None, [conn.host, conn.port])) + ) + config = datahub.ingestion.sink.datahub_kafka.KafkaSinkConfig.parse_obj(obj) + return config + + def make_emitter(self) -> "DatahubKafkaEmitter": + import datahub.emitter.kafka_emitter + + sink_config = self._get_config() + return datahub.emitter.kafka_emitter.DatahubKafkaEmitter(sink_config) + + def emit_mces(self, mces: List[MetadataChangeEvent]) -> None: + emitter = self.make_emitter() + errors = [] + + def callback(exc, msg): + if exc: + errors.append(exc) + + for mce in mces: + emitter.emit_mce_async(mce, callback) + + emitter.flush() + + if errors: + raise AirflowException(f"failed to push some MCEs: {errors}") + + def emit_mcps(self, mcps: List[MetadataChangeProposal]) -> None: + emitter = self.make_emitter() + errors = [] + + def callback(exc, msg): + if exc: + errors.append(exc) + + for mcp in mcps: + emitter.emit_mcp_async(mcp, callback) + + emitter.flush() + + if errors: + raise AirflowException(f"failed to push some MCPs: {errors}") + + +class DatahubGenericHook(BaseHook): + """ + Emits Metadata Change Events using either the DatahubRestHook or the + DatahubKafkaHook. Set up a DataHub Rest or Kafka connection to use. + + :param datahub_conn_id: Reference to the DataHub connection. + :type datahub_conn_id: str + """ + + def __init__(self, datahub_conn_id: str) -> None: + super().__init__() + self.datahub_conn_id = datahub_conn_id + + def get_underlying_hook(self) -> Union[DatahubRestHook, DatahubKafkaHook]: + conn = self.get_connection(self.datahub_conn_id) + + # We need to figure out the underlying hook type. First check the + # conn_type. If that fails, attempt to guess using the conn id name. 
+ if conn.conn_type == DatahubRestHook.conn_type: + return DatahubRestHook(self.datahub_conn_id) + elif conn.conn_type == DatahubKafkaHook.conn_type: + return DatahubKafkaHook(self.datahub_conn_id) + elif "rest" in self.datahub_conn_id: + return DatahubRestHook(self.datahub_conn_id) + elif "kafka" in self.datahub_conn_id: + return DatahubKafkaHook(self.datahub_conn_id) + else: + raise AirflowException( + f"DataHub cannot handle conn_type {conn.conn_type} in {conn}" + ) + + def make_emitter(self) -> Union["DatahubRestEmitter", "DatahubKafkaEmitter"]: + return self.get_underlying_hook().make_emitter() + + def emit_mces(self, mces: List[MetadataChangeEvent]) -> None: + return self.get_underlying_hook().emit_mces(mces) diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/lineage/__init__.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/lineage/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/lineage/datahub.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/lineage/datahub.py new file mode 100644 index 0000000000000..c41bb2b2a1e37 --- /dev/null +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/lineage/datahub.py @@ -0,0 +1,91 @@ +import json +from typing import TYPE_CHECKING, Dict, List, Optional + +from airflow.configuration import conf +from airflow.lineage.backend import LineageBackend + +from datahub_airflow_plugin._lineage_core import ( + DatahubBasicLineageConfig, + send_lineage_to_datahub, +) + +if TYPE_CHECKING: + from airflow.models.baseoperator import BaseOperator + + +class DatahubLineageConfig(DatahubBasicLineageConfig): + # If set to true, most runtime errors in the lineage backend will be + # suppressed and will not cause the overall task to fail. Note that + # configuration issues will still throw exceptions. + graceful_exceptions: bool = True + + +def get_lineage_config() -> DatahubLineageConfig: + """Load the lineage config from airflow.cfg.""" + + # The kwargs pattern is also used for secret backends. + kwargs_str = conf.get("lineage", "datahub_kwargs", fallback="{}") + kwargs = json.loads(kwargs_str) + + # Continue to support top-level datahub_conn_id config. + datahub_conn_id = conf.get("lineage", "datahub_conn_id", fallback=None) + if datahub_conn_id: + kwargs["datahub_conn_id"] = datahub_conn_id + + return DatahubLineageConfig.parse_obj(kwargs) + + +class DatahubLineageBackend(LineageBackend): + """ + Sends lineage data from tasks to DataHub. + + Configurable via ``airflow.cfg`` as follows: :: + + # For REST-based: + airflow connections add --conn-type 'datahub_rest' 'datahub_rest_default' --conn-host 'http://localhost:8080' + # For Kafka-based (standard Kafka sink config can be passed via extras): + airflow connections add --conn-type 'datahub_kafka' 'datahub_kafka_default' --conn-host 'broker:9092' --conn-extra '{}' + + [lineage] + backend = datahub_provider.lineage.datahub.DatahubLineageBackend + datahub_kwargs = { + "datahub_conn_id": "datahub_rest_default", + "capture_ownership_info": true, + "capture_tags_info": true, + "graceful_exceptions": true } + # The above indentation is important! + """ + + def __init__(self) -> None: + super().__init__() + + # By attempting to get and parse the config, we can detect configuration errors + # ahead of time. The init method is only called in Airflow 2.x. 
+ _ = get_lineage_config() + + # With Airflow 2.0, this can be an instance method. However, with Airflow 1.10.x, this + # method is used statically, even though LineageBackend declares it as an instance variable. + @staticmethod + def send_lineage( + operator: "BaseOperator", + inlets: Optional[List] = None, # unused + outlets: Optional[List] = None, # unused + context: Optional[Dict] = None, + ) -> None: + config = get_lineage_config() + if not config.enabled: + return + + try: + context = context or {} # ensure not None to satisfy mypy + send_lineage_to_datahub( + config, operator, operator.inlets, operator.outlets, context + ) + except Exception as e: + if config.graceful_exceptions: + operator.log.error(e) + operator.log.info( + "Suppressing error because graceful_exceptions is set" + ) + else: + raise diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/__init__.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub.py new file mode 100644 index 0000000000000..109e7ddfe4dfa --- /dev/null +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub.py @@ -0,0 +1,63 @@ +from typing import List, Union + +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults +from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent + +from datahub_airflow_plugin.hooks.datahub import ( + DatahubGenericHook, + DatahubKafkaHook, + DatahubRestHook, +) + + +class DatahubBaseOperator(BaseOperator): + """ + The DatahubBaseOperator is used as a base operator all DataHub operators. + """ + + ui_color = "#4398c8" + + hook: Union[DatahubRestHook, DatahubKafkaHook] + + # mypy is not a fan of this. Newer versions of Airflow support proper typing for the decorator + # using PEP 612. However, there is not yet a good way to inherit the types of the kwargs from + # the superclass. + @apply_defaults # type: ignore[misc] + def __init__( # type: ignore[no-untyped-def] + self, + *, + datahub_conn_id: str, + **kwargs, + ): + super().__init__(**kwargs) + + self.datahub_conn_id = datahub_conn_id + self.generic_hook = DatahubGenericHook(datahub_conn_id) + + +class DatahubEmitterOperator(DatahubBaseOperator): + """ + Emits a Metadata Change Event to DataHub using either a DataHub + Rest or Kafka connection. + + :param datahub_conn_id: Reference to the DataHub Rest or Kafka Connection. + :type datahub_conn_id: str + """ + + # See above for why these mypy type issues are ignored here. 
+ @apply_defaults # type: ignore[misc] + def __init__( # type: ignore[no-untyped-def] + self, + mces: List[MetadataChangeEvent], + datahub_conn_id: str, + **kwargs, + ): + super().__init__( + datahub_conn_id=datahub_conn_id, + **kwargs, + ) + self.mces = mces + + def execute(self, context): + self.generic_hook.get_underlying_hook().emit_mces(self.mces) diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub_assertion_operator.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub_assertion_operator.py new file mode 100644 index 0000000000000..6f93c09a9e287 --- /dev/null +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub_assertion_operator.py @@ -0,0 +1,78 @@ +import datetime +from typing import Any, List, Optional, Sequence, Union + +from airflow.models import BaseOperator +from datahub.api.circuit_breaker import ( + AssertionCircuitBreaker, + AssertionCircuitBreakerConfig, +) + +from datahub_airflow_plugin.hooks.datahub import DatahubRestHook + + +class DataHubAssertionOperator(BaseOperator): + r""" + DataHub Assertion Circuit Breaker Operator. + + :param urn: The DataHub dataset unique identifier. (templated) + :param datahub_rest_conn_id: The REST datahub connection id to communicate with DataHub + which is set as Airflow connection. + :param check_last_assertion_time: If set it checks assertions after the last operation was set on the dataset. + By default it is True. + :param time_delta: If verify_after_last_update is False it checks for assertion within the time delta. + """ + + template_fields: Sequence[str] = ("urn",) + circuit_breaker: AssertionCircuitBreaker + urn: Union[List[str], str] + + def __init__( # type: ignore[no-untyped-def] + self, + *, + urn: Union[List[str], str], + datahub_rest_conn_id: Optional[str] = None, + check_last_assertion_time: bool = True, + time_delta: Optional[datetime.timedelta] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + hook: DatahubRestHook + if datahub_rest_conn_id is not None: + hook = DatahubRestHook(datahub_rest_conn_id=datahub_rest_conn_id) + else: + hook = DatahubRestHook() + + host, password, timeout_sec = hook._get_config() + self.urn = urn + config: AssertionCircuitBreakerConfig = AssertionCircuitBreakerConfig( + datahub_host=host, + datahub_token=password, + timeout=timeout_sec, + verify_after_last_update=check_last_assertion_time, + time_delta=time_delta if time_delta else datetime.timedelta(days=1), + ) + + self.circuit_breaker = AssertionCircuitBreaker(config=config) + + def execute(self, context: Any) -> bool: + if "datahub_silence_circuit_breakers" in context["dag_run"].conf: + self.log.info( + "Circuit breaker is silenced because datahub_silence_circuit_breakers config is set" + ) + return True + + self.log.info(f"Checking if dataset {self.urn} is ready to be consumed") + if isinstance(self.urn, str): + urns = [self.urn] + elif isinstance(self.urn, list): + urns = self.urn + else: + raise Exception(f"urn parameter has invalid type {type(self.urn)}") + + for urn in urns: + self.log.info(f"Checking if dataset {self.urn} is ready to be consumed") + ret = self.circuit_breaker.is_circuit_breaker_active(urn=urn) + if ret: + raise Exception(f"Dataset {self.urn} is not in consumable state") + + return True diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub_assertion_sensor.py 
b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub_assertion_sensor.py new file mode 100644 index 0000000000000..16e5d1cbe8b1f --- /dev/null +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub_assertion_sensor.py @@ -0,0 +1,78 @@ +import datetime +from typing import Any, List, Optional, Sequence, Union + +from airflow.sensors.base import BaseSensorOperator +from datahub.api.circuit_breaker import ( + AssertionCircuitBreaker, + AssertionCircuitBreakerConfig, +) + +from datahub_airflow_plugin.hooks.datahub import DatahubRestHook + + +class DataHubAssertionSensor(BaseSensorOperator): + r""" + DataHub Assertion Circuit Breaker Sensor. + + :param urn: The DataHub dataset unique identifier. (templated) + :param datahub_rest_conn_id: The REST datahub connection id to communicate with DataHub + which is set as Airflow connection. + :param check_last_assertion_time: If set it checks assertions after the last operation was set on the dataset. + By default it is True. + :param time_delta: If verify_after_last_update is False it checks for assertion within the time delta. + """ + + template_fields: Sequence[str] = ("urn",) + circuit_breaker: AssertionCircuitBreaker + urn: Union[List[str], str] + + def __init__( # type: ignore[no-untyped-def] + self, + *, + urn: Union[List[str], str], + datahub_rest_conn_id: Optional[str] = None, + check_last_assertion_time: bool = True, + time_delta: datetime.timedelta = datetime.timedelta(days=1), + **kwargs, + ) -> None: + super().__init__(**kwargs) + hook: DatahubRestHook + if datahub_rest_conn_id is not None: + hook = DatahubRestHook(datahub_rest_conn_id=datahub_rest_conn_id) + else: + hook = DatahubRestHook() + + host, password, timeout_sec = hook._get_config() + self.urn = urn + config: AssertionCircuitBreakerConfig = AssertionCircuitBreakerConfig( + datahub_host=host, + datahub_token=password, + timeout=timeout_sec, + verify_after_last_update=check_last_assertion_time, + time_delta=time_delta, + ) + self.circuit_breaker = AssertionCircuitBreaker(config=config) + + def poke(self, context: Any) -> bool: + if "datahub_silence_circuit_breakers" in context["dag_run"].conf: + self.log.info( + "Circuit breaker is silenced because datahub_silence_circuit_breakers config is set" + ) + return True + + self.log.info(f"Checking if dataset {self.urn} is ready to be consumed") + if isinstance(self.urn, str): + urns = [self.urn] + elif isinstance(self.urn, list): + urns = self.urn + else: + raise Exception(f"urn parameter has invalid type {type(self.urn)}") + + for urn in urns: + self.log.info(f"Checking if dataset {self.urn} is ready to be consumed") + ret = self.circuit_breaker.is_circuit_breaker_active(urn=urn) + if ret: + self.log.info(f"Dataset {self.urn} is not in consumable state") + return False + + return True diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub_operation_operator.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub_operation_operator.py new file mode 100644 index 0000000000000..94e105309537b --- /dev/null +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub_operation_operator.py @@ -0,0 +1,97 @@ +import datetime +from typing import Any, List, Optional, Sequence, Union + +from airflow.sensors.base import BaseSensorOperator +from datahub.api.circuit_breaker import ( + OperationCircuitBreaker, + OperationCircuitBreakerConfig, +) + +from 
datahub_airflow_plugin.hooks.datahub import DatahubRestHook + + +class DataHubOperationCircuitBreakerOperator(BaseSensorOperator): + r""" + DataHub Operation Circuit Breaker Operator. + + :param urn: The DataHub dataset unique identifier. (templated) + :param datahub_rest_conn_id: The REST datahub connection id to communicate with DataHub + which is set as Airflow connection. + :param partition: The partition to check the operation. + :param source_type: The partition to check the operation. :ref:`https://datahubproject.io/docs/graphql/enums#operationsourcetype` + + """ + + template_fields: Sequence[str] = ( + "urn", + "partition", + "source_type", + "operation_type", + ) + circuit_breaker: OperationCircuitBreaker + urn: Union[List[str], str] + partition: Optional[str] + source_type: Optional[str] + operation_type: Optional[str] + + def __init__( # type: ignore[no-untyped-def] + self, + *, + urn: Union[List[str], str], + datahub_rest_conn_id: Optional[str] = None, + time_delta: Optional[datetime.timedelta] = datetime.timedelta(days=1), + partition: Optional[str] = None, + source_type: Optional[str] = None, + operation_type: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + hook: DatahubRestHook + if datahub_rest_conn_id is not None: + hook = DatahubRestHook(datahub_rest_conn_id=datahub_rest_conn_id) + else: + hook = DatahubRestHook() + + host, password, timeout_sec = hook._get_config() + + self.urn = urn + self.partition = partition + self.operation_type = operation_type + self.source_type = source_type + + config: OperationCircuitBreakerConfig = OperationCircuitBreakerConfig( + datahub_host=host, + datahub_token=password, + timeout=timeout_sec, + time_delta=time_delta, + ) + + self.circuit_breaker = OperationCircuitBreaker(config=config) + + def execute(self, context: Any) -> bool: + if "datahub_silence_circuit_breakers" in context["dag_run"].conf: + self.log.info( + "Circuit breaker is silenced because datahub_silence_circuit_breakers config is set" + ) + return True + + self.log.info(f"Checking if dataset {self.urn} is ready to be consumed") + if isinstance(self.urn, str): + urns = [self.urn] + elif isinstance(self.urn, list): + urns = self.urn + else: + raise Exception(f"urn parameter has invalid type {type(self.urn)}") + + for urn in urns: + self.log.info(f"Checking if dataset {self.urn} is ready to be consumed") + ret = self.circuit_breaker.is_circuit_breaker_active( + urn=urn, + partition=self.partition, + operation_type=self.operation_type, + source_type=self.source_type, + ) + if ret: + raise Exception(f"Dataset {self.urn} is not in consumable state") + + return True diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub_operation_sensor.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub_operation_sensor.py new file mode 100644 index 0000000000000..434c60754064d --- /dev/null +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/operators/datahub_operation_sensor.py @@ -0,0 +1,100 @@ +import datetime +from typing import Any, List, Optional, Sequence, Union + +from airflow.sensors.base import BaseSensorOperator +from datahub.api.circuit_breaker import ( + OperationCircuitBreaker, + OperationCircuitBreakerConfig, +) + +from datahub_airflow_plugin.hooks.datahub import DatahubRestHook + + +class DataHubOperationCircuitBreakerSensor(BaseSensorOperator): + r""" + DataHub Operation Circuit Breaker Sensor. 
+ + :param urn: The DataHub dataset unique identifier. (templated) + :param datahub_rest_conn_id: The REST datahub connection id to communicate with DataHub + which is set as Airflow connection. + :param partition: The partition to check the operation. + :param source_type: The source type to filter on. If not set it will accept any source type. + See valid values at: https://datahubproject.io/docs/graphql/enums#operationsourcetype + :param operation_type: The operation type to filter on. If not set it will accept any source type. + See valid values at: https://datahubproject.io/docs/graphql/enums/#operationtype + """ + + template_fields: Sequence[str] = ( + "urn", + "partition", + "source_type", + "operation_type", + ) + circuit_breaker: OperationCircuitBreaker + urn: Union[List[str], str] + partition: Optional[str] + source_type: Optional[str] + operation_type: Optional[str] + + def __init__( # type: ignore[no-untyped-def] + self, + *, + urn: Union[List[str], str], + datahub_rest_conn_id: Optional[str] = None, + time_delta: Optional[datetime.timedelta] = datetime.timedelta(days=1), + partition: Optional[str] = None, + source_type: Optional[str] = None, + operation_type: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + hook: DatahubRestHook + if datahub_rest_conn_id is not None: + hook = DatahubRestHook(datahub_rest_conn_id=datahub_rest_conn_id) + else: + hook = DatahubRestHook() + + host, password, timeout_sec = hook._get_config() + + self.urn = urn + self.partition = partition + self.operation_type = operation_type + self.source_type = source_type + + config: OperationCircuitBreakerConfig = OperationCircuitBreakerConfig( + datahub_host=host, + datahub_token=password, + timeout=timeout_sec, + time_delta=time_delta, + ) + + self.circuit_breaker = OperationCircuitBreaker(config=config) + + def poke(self, context: Any) -> bool: + if "datahub_silence_circuit_breakers" in context["dag_run"].conf: + self.log.info( + "Circuit breaker is silenced because datahub_silence_circuit_breakers config is set" + ) + return True + + self.log.info(f"Checking if dataset {self.urn} is ready to be consumed") + if isinstance(self.urn, str): + urns = [self.urn] + elif isinstance(self.urn, list): + urns = self.urn + else: + raise Exception(f"urn parameter has invalid type {type(self.urn)}") + + for urn in urns: + self.log.info(f"Checking if dataset {self.urn} is ready to be consumed") + ret = self.circuit_breaker.is_circuit_breaker_active( + urn=urn, + partition=self.partition, + operation_type=self.operation_type, + source_type=self.source_type, + ) + if ret: + self.log.info(f"Dataset {self.urn} is not in consumable state") + return False + + return True diff --git a/metadata-ingestion/tests/unit/test_airflow.py b/metadata-ingestion-modules/airflow-plugin/tests/unit/test_airflow.py similarity index 97% rename from metadata-ingestion/tests/unit/test_airflow.py rename to metadata-ingestion-modules/airflow-plugin/tests/unit/test_airflow.py index 980dc5550fafa..9aa901171cfa6 100644 --- a/metadata-ingestion/tests/unit/test_airflow.py +++ b/metadata-ingestion-modules/airflow-plugin/tests/unit/test_airflow.py @@ -9,12 +9,11 @@ import airflow.configuration import airflow.version +import datahub.emitter.mce_builder as builder import packaging.version import pytest from airflow.lineage import apply_lineage, prepare_lineage from airflow.models import DAG, Connection, DagBag, DagRun, TaskInstance - -import datahub.emitter.mce_builder as builder from datahub_provider import get_provider_info 
from datahub_provider._airflow_shims import AIRFLOW_PATCHED, EmptyOperator from datahub_provider.entities import Dataset, Urn @@ -23,7 +22,7 @@ assert AIRFLOW_PATCHED -pytestmark = pytest.mark.airflow +# TODO: Remove default_view="tree" arg. Figure out why is default_view being picked as "grid" and how to fix it ? # Approach suggested by https://stackoverflow.com/a/11887885/5004662. AIRFLOW_VERSION = packaging.version.parse(airflow.version.version) @@ -75,7 +74,7 @@ def test_airflow_provider_info(): @pytest.mark.filterwarnings("ignore:.*is deprecated.*") def test_dags_load_with_no_errors(pytestconfig: pytest.Config) -> None: airflow_examples_folder = ( - pytestconfig.rootpath / "src/datahub_provider/example_dags" + pytestconfig.rootpath / "src/datahub_airflow_plugin/example_dags" ) # Note: the .airflowignore file skips the snowflake DAG. @@ -233,7 +232,11 @@ def test_lineage_backend(mock_emit, inlets, outlets, capture_executions): func = mock.Mock() func.__name__ = "foo" - dag = DAG(dag_id="test_lineage_is_sent_to_backend", start_date=DEFAULT_DATE) + dag = DAG( + dag_id="test_lineage_is_sent_to_backend", + start_date=DEFAULT_DATE, + default_view="tree", + ) with dag: op1 = EmptyOperator( @@ -252,6 +255,7 @@ def test_lineage_backend(mock_emit, inlets, outlets, capture_executions): # versions do not require it, but will attempt to find the associated # run_id in the database if execution_date is provided. As such, we # must fake the run_id parameter for newer Airflow versions. + # We need to add type:ignore in else to suppress mypy error in Airflow < 2.2 if AIRFLOW_VERSION < packaging.version.parse("2.2.0"): ti = TaskInstance(task=op2, execution_date=DEFAULT_DATE) # Ignoring type here because DagRun state is just a sring at Airflow 1 @@ -259,7 +263,7 @@ def test_lineage_backend(mock_emit, inlets, outlets, capture_executions): else: from airflow.utils.state import DagRunState - ti = TaskInstance(task=op2, run_id=f"test_airflow-{DEFAULT_DATE}") + ti = TaskInstance(task=op2, run_id=f"test_airflow-{DEFAULT_DATE}") # type: ignore[call-arg] dag_run = DagRun( state=DagRunState.SUCCESS, run_id=f"scheduled_{DEFAULT_DATE.isoformat()}", diff --git a/metadata-ingestion/developing.md b/metadata-ingestion/developing.md index 5d49b9a866a3d..f529590e2ab39 100644 --- a/metadata-ingestion/developing.md +++ b/metadata-ingestion/developing.md @@ -26,6 +26,16 @@ source venv/bin/activate datahub version # should print "DataHub CLI version: unavailable (installed in develop mode)" ``` +### (Optional) Set up your Python environment for developing on Airflow Plugin + +From the repository root: + +```shell +cd metadata-ingestion-modules/airflow-plugin +../../gradlew :metadata-ingestion-modules:airflow-plugin:installDev +source venv/bin/activate +datahub version # should print "DataHub CLI version: unavailable (installed in develop mode)" +``` ### Common setup issues Common issues (click to expand): @@ -183,7 +193,7 @@ pytest -m 'slow_integration' ../gradlew :metadata-ingestion:testFull ../gradlew :metadata-ingestion:check # Run all tests in a single file -../gradlew :metadata-ingestion:testSingle -PtestFile=tests/unit/test_airflow.py +../gradlew :metadata-ingestion:testSingle -PtestFile=tests/unit/test_bigquery_source.py # Run all tests under tests/unit ../gradlew :metadata-ingestion:testSingle -PtestFile=tests/unit ``` diff --git a/metadata-ingestion/schedule_docs/airflow.md b/metadata-ingestion/schedule_docs/airflow.md index e48710964b01c..95393c3cc9919 100644 --- a/metadata-ingestion/schedule_docs/airflow.md +++ 
b/metadata-ingestion/schedule_docs/airflow.md @@ -4,9 +4,9 @@ If you are using Apache Airflow for your scheduling then you might want to also We've provided a few examples of how to configure your DAG: -- [`mysql_sample_dag`](../src/datahub_provider/example_dags/mysql_sample_dag.py) embeds the full MySQL ingestion configuration inside the DAG. +- [`mysql_sample_dag`](../../metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/mysql_sample_dag.py) embeds the full MySQL ingestion configuration inside the DAG. -- [`snowflake_sample_dag`](../src/datahub_provider/example_dags/snowflake_sample_dag.py) avoids embedding credentials inside the recipe, and instead fetches them from Airflow's [Connections](https://airflow.apache.org/docs/apache-airflow/stable/howto/connection/index.html) feature. You must configure your connections in Airflow to use this approach. +- [`snowflake_sample_dag`](../../metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/snowflake_sample_dag.py) avoids embedding credentials inside the recipe, and instead fetches them from Airflow's [Connections](https://airflow.apache.org/docs/apache-airflow/stable/howto/connection/index.html) feature. You must configure your connections in Airflow to use this approach. :::tip @@ -37,6 +37,6 @@ In more advanced cases, you might want to store your ingestion recipe in a file - Create a DAG task to read your DataHub ingestion recipe file and run it. See the example below for reference. - Deploy the DAG file into airflow for scheduling. Typically this involves checking in the DAG file into your dags folder which is accessible to your Airflow instance. -Example: [`generic_recipe_sample_dag`](../src/datahub_provider/example_dags/generic_recipe_sample_dag.py) +Example: [`generic_recipe_sample_dag`](../../metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/example_dags/generic_recipe_sample_dag.py) diff --git a/metadata-ingestion/setup.cfg b/metadata-ingestion/setup.cfg index 59d847395ec47..fad55b99ec938 100644 --- a/metadata-ingestion/setup.cfg +++ b/metadata-ingestion/setup.cfg @@ -75,7 +75,6 @@ disallow_untyped_defs = yes asyncio_mode = auto addopts = --cov=src --cov-report= --cov-config setup.cfg --strict-markers markers = - airflow: marks tests related to airflow (deselect with '-m not airflow') slow_unit: marks tests to only run slow unit tests (deselect with '-m not slow_unit') integration: marks tests to only run in integration (deselect with '-m "not integration"') integration_batch_1: mark tests to only run in batch 1 of integration tests. This is done mainly for parallelisation (deselect with '-m not integration_batch_1') @@ -112,5 +111,3 @@ exclude_lines = omit = # omit codegen src/datahub/metadata/* - # omit example dags - src/datahub_provider/example_dags/* diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index f0b66f8bbfb96..aa01882a44aa6 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -283,8 +283,7 @@ def get_long_description(): }, # Integrations. 
"airflow": { - "apache-airflow >= 2.0.2", - *rest_common, + f"acryl-datahub-airflow-plugin == {package_metadata['__version__']}", }, "circuit-breaker": { "gql>=3.3.0", @@ -508,8 +507,8 @@ def get_long_description(): "salesforce", "unity-catalog", "nifi", - "vertica" - # airflow is added below + "vertica", + "mode", ] if plugin for dependency in plugins[plugin] @@ -518,9 +517,6 @@ def get_long_description(): dev_requirements = { *base_dev_requirements, - # Extra requirements for Airflow. - "apache-airflow[snowflake]>=2.0.2", # snowflake is used in example dags - "virtualenv", # needed by PythonVirtualenvOperator } full_test_dev_requirements = { diff --git a/metadata-ingestion/src/datahub_provider/__init__.py b/metadata-ingestion/src/datahub_provider/__init__.py index 4c0b2bd8e714e..306076dadf82b 100644 --- a/metadata-ingestion/src/datahub_provider/__init__.py +++ b/metadata-ingestion/src/datahub_provider/__init__.py @@ -1,28 +1 @@ -import datahub - - -# This is needed to allow Airflow to pick up specific metadata fields it needs for -# certain features. We recognize it's a bit unclean to define these in multiple places, -# but at this point it's the only workaround if you'd like your custom conn type to -# show up in the Airflow UI. -def get_provider_info(): - return { - "name": "DataHub", - "description": "`DataHub `__\n", - "connection-types": [ - { - "hook-class-name": "datahub_provider.hooks.datahub.DatahubRestHook", - "connection-type": "datahub_rest", - }, - { - "hook-class-name": "datahub_provider.hooks.datahub.DatahubKafkaHook", - "connection-type": "datahub_kafka", - }, - ], - "hook-class-names": [ - "datahub_provider.hooks.datahub.DatahubRestHook", - "datahub_provider.hooks.datahub.DatahubKafkaHook", - ], - "package-name": datahub.__package_name__, - "versions": [datahub.__version__], - } +from datahub_airflow_plugin import get_provider_info diff --git a/metadata-ingestion/src/datahub_provider/_airflow_compat.py b/metadata-ingestion/src/datahub_provider/_airflow_compat.py index 67c3348ec987c..98b96e32fee78 100644 --- a/metadata-ingestion/src/datahub_provider/_airflow_compat.py +++ b/metadata-ingestion/src/datahub_provider/_airflow_compat.py @@ -1,12 +1,3 @@ -# This module must be imported before any Airflow imports in any of our files. -# The AIRFLOW_PATCHED just helps avoid flake8 errors. +from datahub_airflow_plugin._airflow_compat import AIRFLOW_PATCHED -from datahub.utilities._markupsafe_compat import MARKUPSAFE_PATCHED - -assert MARKUPSAFE_PATCHED - -AIRFLOW_PATCHED = True - -__all__ = [ - "AIRFLOW_PATCHED", -] +__all__ = ["AIRFLOW_PATCHED"] diff --git a/metadata-ingestion/src/datahub_provider/_airflow_shims.py b/metadata-ingestion/src/datahub_provider/_airflow_shims.py index 31e1237c0d21d..d5e4a019a4b81 100644 --- a/metadata-ingestion/src/datahub_provider/_airflow_shims.py +++ b/metadata-ingestion/src/datahub_provider/_airflow_shims.py @@ -1,29 +1,15 @@ -from datahub_provider._airflow_compat import AIRFLOW_PATCHED - -from airflow.models.baseoperator import BaseOperator - -try: - from airflow.models.mappedoperator import MappedOperator - from airflow.models.operator import Operator - from airflow.operators.empty import EmptyOperator -except ModuleNotFoundError: - # Operator isn't a real class, but rather a type alias defined - # as the union of BaseOperator and MappedOperator. - # Since older versions of Airflow don't have MappedOperator, we can just use BaseOperator. 
- Operator = BaseOperator # type: ignore - MappedOperator = None # type: ignore - from airflow.operators.dummy import DummyOperator as EmptyOperator # type: ignore - -try: - from airflow.sensors.external_task import ExternalTaskSensor -except ImportError: - from airflow.sensors.external_task_sensor import ExternalTaskSensor # type: ignore - -assert AIRFLOW_PATCHED +from datahub_airflow_plugin._airflow_shims import ( + AIRFLOW_PATCHED, + EmptyOperator, + ExternalTaskSensor, + MappedOperator, + Operator, +) __all__ = [ - "Operator", - "MappedOperator", + "AIRFLOW_PATCHED", "EmptyOperator", "ExternalTaskSensor", + "Operator", + "MappedOperator", ] diff --git a/metadata-ingestion/src/datahub_provider/_lineage_core.py b/metadata-ingestion/src/datahub_provider/_lineage_core.py index 07c70eeca4e6d..4305b39cac684 100644 --- a/metadata-ingestion/src/datahub_provider/_lineage_core.py +++ b/metadata-ingestion/src/datahub_provider/_lineage_core.py @@ -1,114 +1,3 @@ -from datetime import datetime -from typing import TYPE_CHECKING, Dict, List +from datahub_airflow_plugin._lineage_core import DatahubBasicLineageConfig -import datahub.emitter.mce_builder as builder -from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult -from datahub.configuration.common import ConfigModel -from datahub.utilities.urns.dataset_urn import DatasetUrn -from datahub_provider.client.airflow_generator import AirflowGenerator -from datahub_provider.entities import _Entity - -if TYPE_CHECKING: - from airflow import DAG - from airflow.models.dagrun import DagRun - from airflow.models.taskinstance import TaskInstance - - from datahub_provider._airflow_shims import Operator - from datahub_provider.hooks.datahub import DatahubGenericHook - - -def _entities_to_urn_list(iolets: List[_Entity]) -> List[DatasetUrn]: - return [DatasetUrn.create_from_string(let.urn) for let in iolets] - - -class DatahubBasicLineageConfig(ConfigModel): - enabled: bool = True - - # DataHub hook connection ID. - datahub_conn_id: str - - # Cluster to associate with the pipelines and tasks. Defaults to "prod". - cluster: str = builder.DEFAULT_FLOW_CLUSTER - - # If true, the owners field of the DAG will be capture as a DataHub corpuser. - capture_ownership_info: bool = True - - # If true, the tags field of the DAG will be captured as DataHub tags. - capture_tags_info: bool = True - - capture_executions: bool = False - - def make_emitter_hook(self) -> "DatahubGenericHook": - # This is necessary to avoid issues with circular imports. 
- from datahub_provider.hooks.datahub import DatahubGenericHook - - return DatahubGenericHook(self.datahub_conn_id) - - -def send_lineage_to_datahub( - config: DatahubBasicLineageConfig, - operator: "Operator", - inlets: List[_Entity], - outlets: List[_Entity], - context: Dict, -) -> None: - if not config.enabled: - return - - dag: "DAG" = context["dag"] - task: "Operator" = context["task"] - ti: "TaskInstance" = context["task_instance"] - - hook = config.make_emitter_hook() - emitter = hook.make_emitter() - - dataflow = AirflowGenerator.generate_dataflow( - cluster=config.cluster, - dag=dag, - capture_tags=config.capture_tags_info, - capture_owner=config.capture_ownership_info, - ) - dataflow.emit(emitter) - operator.log.info(f"Emitted from Lineage: {dataflow}") - - datajob = AirflowGenerator.generate_datajob( - cluster=config.cluster, - task=task, - dag=dag, - capture_tags=config.capture_tags_info, - capture_owner=config.capture_ownership_info, - ) - datajob.inlets.extend(_entities_to_urn_list(inlets)) - datajob.outlets.extend(_entities_to_urn_list(outlets)) - - datajob.emit(emitter) - operator.log.info(f"Emitted from Lineage: {datajob}") - - if config.capture_executions: - dag_run: "DagRun" = context["dag_run"] - - dpi = AirflowGenerator.run_datajob( - emitter=emitter, - cluster=config.cluster, - ti=ti, - dag=dag, - dag_run=dag_run, - datajob=datajob, - emit_templates=False, - ) - - operator.log.info(f"Emitted from Lineage: {dpi}") - - dpi = AirflowGenerator.complete_datajob( - emitter=emitter, - cluster=config.cluster, - ti=ti, - dag=dag, - dag_run=dag_run, - datajob=datajob, - result=InstanceRunResult.SUCCESS, - end_timestamp_millis=int(datetime.utcnow().timestamp() * 1000), - ) - operator.log.info(f"Emitted from Lineage: {dpi}") - - emitter.flush() +__all__ = ["DatahubBasicLineageConfig"] diff --git a/metadata-ingestion/src/datahub_provider/_plugin.py b/metadata-ingestion/src/datahub_provider/_plugin.py index ed2e4e1c93d80..3d74e715bd644 100644 --- a/metadata-ingestion/src/datahub_provider/_plugin.py +++ b/metadata-ingestion/src/datahub_provider/_plugin.py @@ -1,368 +1,3 @@ -from datahub_provider._airflow_compat import AIRFLOW_PATCHED +from datahub_airflow_plugin.datahub_plugin import DatahubPlugin -import contextlib -import logging -import traceback -from typing import Any, Callable, Iterable, List, Optional, Union - -from airflow.configuration import conf -from airflow.lineage import PIPELINE_OUTLETS -from airflow.models.baseoperator import BaseOperator -from airflow.plugins_manager import AirflowPlugin -from airflow.utils.module_loading import import_string -from cattr import structure - -from datahub.api.entities.dataprocess.dataprocess_instance import InstanceRunResult -from datahub_provider._airflow_shims import MappedOperator, Operator -from datahub_provider.client.airflow_generator import AirflowGenerator -from datahub_provider.hooks.datahub import DatahubGenericHook -from datahub_provider.lineage.datahub import DatahubLineageConfig - -assert AIRFLOW_PATCHED -logger = logging.getLogger(__name__) - -TASK_ON_FAILURE_CALLBACK = "on_failure_callback" -TASK_ON_SUCCESS_CALLBACK = "on_success_callback" - - -def get_lineage_config() -> DatahubLineageConfig: - """Load the lineage config from airflow.cfg.""" - - enabled = conf.get("datahub", "enabled", fallback=True) - datahub_conn_id = conf.get("datahub", "conn_id", fallback="datahub_rest_default") - cluster = conf.get("datahub", "cluster", fallback="prod") - graceful_exceptions = conf.get("datahub", "graceful_exceptions", 
fallback=True) - capture_tags_info = conf.get("datahub", "capture_tags_info", fallback=True) - capture_ownership_info = conf.get( - "datahub", "capture_ownership_info", fallback=True - ) - capture_executions = conf.get("datahub", "capture_executions", fallback=True) - return DatahubLineageConfig( - enabled=enabled, - datahub_conn_id=datahub_conn_id, - cluster=cluster, - graceful_exceptions=graceful_exceptions, - capture_ownership_info=capture_ownership_info, - capture_tags_info=capture_tags_info, - capture_executions=capture_executions, - ) - - -def _task_inlets(operator: "Operator") -> List: - # From Airflow 2.4 _inlets is dropped and inlets used consistently. Earlier it was not the case, so we have to stick there to _inlets - if hasattr(operator, "_inlets"): - return operator._inlets # type: ignore[attr-defined, union-attr] - return operator.inlets - - -def _task_outlets(operator: "Operator") -> List: - # From Airflow 2.4 _outlets is dropped and inlets used consistently. Earlier it was not the case, so we have to stick there to _outlets - # We have to use _outlets because outlets is empty in Airflow < 2.4.0 - if hasattr(operator, "_outlets"): - return operator._outlets # type: ignore[attr-defined, union-attr] - return operator.outlets - - -def get_inlets_from_task(task: BaseOperator, context: Any) -> Iterable[Any]: - # TODO: Fix for https://github.com/apache/airflow/commit/1b1f3fabc5909a447a6277cafef3a0d4ef1f01ae - # in Airflow 2.4. - # TODO: ignore/handle airflow's dataset type in our lineage - - inlets: List[Any] = [] - task_inlets = _task_inlets(task) - # From Airflow 2.3 this should be AbstractOperator but due to compatibility reason lets use BaseOperator - if isinstance(task_inlets, (str, BaseOperator)): - inlets = [ - task_inlets, - ] - - if task_inlets and isinstance(task_inlets, list): - inlets = [] - task_ids = ( - {o for o in task_inlets if isinstance(o, str)} - .union(op.task_id for op in task_inlets if isinstance(op, BaseOperator)) - .intersection(task.get_flat_relative_ids(upstream=True)) - ) - - from airflow.lineage import AUTO - - # pick up unique direct upstream task_ids if AUTO is specified - if AUTO.upper() in task_inlets or AUTO.lower() in task_inlets: - print("Picking up unique direct upstream task_ids as AUTO is specified") - task_ids = task_ids.union( - task_ids.symmetric_difference(task.upstream_task_ids) - ) - - inlets = task.xcom_pull( - context, task_ids=list(task_ids), dag_id=task.dag_id, key=PIPELINE_OUTLETS - ) - - # re-instantiate the obtained inlets - inlets = [ - structure(item["data"], import_string(item["type_name"])) - # _get_instance(structure(item, Metadata)) - for sublist in inlets - if sublist - for item in sublist - ] - - for inlet in task_inlets: - if not isinstance(inlet, str): - inlets.append(inlet) - - return inlets - - -def _make_emit_callback( - logger: logging.Logger, -) -> Callable[[Optional[Exception], str], None]: - def emit_callback(err: Optional[Exception], msg: str) -> None: - if err: - logger.error(f"Error sending metadata to datahub: {msg}", exc_info=err) - - return emit_callback - - -def datahub_task_status_callback(context, status): - ti = context["ti"] - task: "BaseOperator" = ti.task - dag = context["dag"] - - # This code is from the original airflow lineage code -> - # https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py - inlets = get_inlets_from_task(task, context) - - emitter = ( - DatahubGenericHook(context["_datahub_config"].datahub_conn_id) - .get_underlying_hook() - .make_emitter() - ) - - dataflow = 
AirflowGenerator.generate_dataflow( - cluster=context["_datahub_config"].cluster, - dag=dag, - capture_tags=context["_datahub_config"].capture_tags_info, - capture_owner=context["_datahub_config"].capture_ownership_info, - ) - task.log.info(f"Emitting Datahub Dataflow: {dataflow}") - dataflow.emit(emitter, callback=_make_emit_callback(task.log)) - - datajob = AirflowGenerator.generate_datajob( - cluster=context["_datahub_config"].cluster, - task=task, - dag=dag, - capture_tags=context["_datahub_config"].capture_tags_info, - capture_owner=context["_datahub_config"].capture_ownership_info, - ) - - for inlet in inlets: - datajob.inlets.append(inlet.urn) - - task_outlets = _task_outlets(task) - for outlet in task_outlets: - datajob.outlets.append(outlet.urn) - - task.log.info(f"Emitting Datahub Datajob: {datajob}") - datajob.emit(emitter, callback=_make_emit_callback(task.log)) - - if context["_datahub_config"].capture_executions: - dpi = AirflowGenerator.run_datajob( - emitter=emitter, - cluster=context["_datahub_config"].cluster, - ti=context["ti"], - dag=dag, - dag_run=context["dag_run"], - datajob=datajob, - start_timestamp_millis=int(ti.start_date.timestamp() * 1000), - ) - - task.log.info(f"Emitted Start Datahub Dataprocess Instance: {dpi}") - - dpi = AirflowGenerator.complete_datajob( - emitter=emitter, - cluster=context["_datahub_config"].cluster, - ti=context["ti"], - dag_run=context["dag_run"], - result=status, - dag=dag, - datajob=datajob, - end_timestamp_millis=int(ti.end_date.timestamp() * 1000), - ) - task.log.info(f"Emitted Completed Data Process Instance: {dpi}") - - emitter.flush() - - -def datahub_pre_execution(context): - ti = context["ti"] - task: "BaseOperator" = ti.task - dag = context["dag"] - - task.log.info("Running Datahub pre_execute method") - - emitter = ( - DatahubGenericHook(context["_datahub_config"].datahub_conn_id) - .get_underlying_hook() - .make_emitter() - ) - - # This code is from the original airflow lineage code -> - # https://github.com/apache/airflow/blob/main/airflow/lineage/__init__.py - inlets = get_inlets_from_task(task, context) - - datajob = AirflowGenerator.generate_datajob( - cluster=context["_datahub_config"].cluster, - task=context["ti"].task, - dag=dag, - capture_tags=context["_datahub_config"].capture_tags_info, - capture_owner=context["_datahub_config"].capture_ownership_info, - ) - - for inlet in inlets: - datajob.inlets.append(inlet.urn) - - task_outlets = _task_outlets(task) - - for outlet in task_outlets: - datajob.outlets.append(outlet.urn) - - task.log.info(f"Emitting Datahub dataJob {datajob}") - datajob.emit(emitter, callback=_make_emit_callback(task.log)) - - if context["_datahub_config"].capture_executions: - dpi = AirflowGenerator.run_datajob( - emitter=emitter, - cluster=context["_datahub_config"].cluster, - ti=context["ti"], - dag=dag, - dag_run=context["dag_run"], - datajob=datajob, - start_timestamp_millis=int(ti.start_date.timestamp() * 1000), - ) - - task.log.info(f"Emitting Datahub Dataprocess Instance: {dpi}") - - emitter.flush() - - -def _wrap_pre_execution(pre_execution): - def custom_pre_execution(context): - config = get_lineage_config() - if config.enabled: - context["_datahub_config"] = config - datahub_pre_execution(context) - - # Call original policy - if pre_execution: - pre_execution(context) - - return custom_pre_execution - - -def _wrap_on_failure_callback(on_failure_callback): - def custom_on_failure_callback(context): - config = get_lineage_config() - if config.enabled: - context["_datahub_config"] = config 
- try: - datahub_task_status_callback(context, status=InstanceRunResult.FAILURE) - except Exception as e: - if not config.graceful_exceptions: - raise e - else: - print(f"Exception: {traceback.format_exc()}") - - # Call original policy - if on_failure_callback: - on_failure_callback(context) - - return custom_on_failure_callback - - -def _wrap_on_success_callback(on_success_callback): - def custom_on_success_callback(context): - config = get_lineage_config() - if config.enabled: - context["_datahub_config"] = config - try: - datahub_task_status_callback(context, status=InstanceRunResult.SUCCESS) - except Exception as e: - if not config.graceful_exceptions: - raise e - else: - print(f"Exception: {traceback.format_exc()}") - - # Call original policy - if on_success_callback: - on_success_callback(context) - - return custom_on_success_callback - - -def task_policy(task: Union[BaseOperator, MappedOperator]) -> None: - task.log.debug(f"Setting task policy for Dag: {task.dag_id} Task: {task.task_id}") - # task.add_inlets(["auto"]) - # task.pre_execute = _wrap_pre_execution(task.pre_execute) - - # MappedOperator's callbacks don't have setters until Airflow 2.X.X - # https://github.com/apache/airflow/issues/24547 - # We can bypass this by going through partial_kwargs for now - if MappedOperator and isinstance(task, MappedOperator): # type: ignore - on_failure_callback_prop: property = getattr( - MappedOperator, TASK_ON_FAILURE_CALLBACK - ) - on_success_callback_prop: property = getattr( - MappedOperator, TASK_ON_SUCCESS_CALLBACK - ) - if not on_failure_callback_prop.fset or not on_success_callback_prop.fset: - task.log.debug( - "Using MappedOperator's partial_kwargs instead of callback properties" - ) - task.partial_kwargs[TASK_ON_FAILURE_CALLBACK] = _wrap_on_failure_callback( - task.on_failure_callback - ) - task.partial_kwargs[TASK_ON_SUCCESS_CALLBACK] = _wrap_on_success_callback( - task.on_success_callback - ) - return - - task.on_failure_callback = _wrap_on_failure_callback(task.on_failure_callback) # type: ignore - task.on_success_callback = _wrap_on_success_callback(task.on_success_callback) # type: ignore - # task.pre_execute = _wrap_pre_execution(task.pre_execute) - - -def _wrap_task_policy(policy): - if policy and hasattr(policy, "_task_policy_patched_by"): - return policy - - def custom_task_policy(task): - policy(task) - task_policy(task) - - # Add a flag to the policy to indicate that we've patched it. 
- custom_task_policy._task_policy_patched_by = "datahub_plugin" # type: ignore[attr-defined] - return custom_task_policy - - -def _patch_policy(settings): - if hasattr(settings, "task_policy"): - datahub_task_policy = _wrap_task_policy(settings.task_policy) - settings.task_policy = datahub_task_policy - - -def _patch_datahub_policy(): - with contextlib.suppress(ImportError): - import airflow_local_settings - - _patch_policy(airflow_local_settings) - - from airflow.models.dagbag import settings - - _patch_policy(settings) - - -_patch_datahub_policy() - - -class DatahubPlugin(AirflowPlugin): - name = "datahub_plugin" +__all__ = ["DatahubPlugin"] diff --git a/metadata-ingestion/src/datahub_provider/client/airflow_generator.py b/metadata-ingestion/src/datahub_provider/client/airflow_generator.py index d2d29b00d244f..d50ae152f2b1e 100644 --- a/metadata-ingestion/src/datahub_provider/client/airflow_generator.py +++ b/metadata-ingestion/src/datahub_provider/client/airflow_generator.py @@ -1,509 +1,3 @@ -from datahub_provider._airflow_compat import AIRFLOW_PATCHED +from datahub_airflow_plugin.client.airflow_generator import AirflowGenerator -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, cast - -from airflow.configuration import conf - -from datahub.api.entities.datajob import DataFlow, DataJob -from datahub.api.entities.dataprocess.dataprocess_instance import ( - DataProcessInstance, - InstanceRunResult, -) -from datahub.metadata.schema_classes import DataProcessTypeClass -from datahub.utilities.urns.data_flow_urn import DataFlowUrn -from datahub.utilities.urns.data_job_urn import DataJobUrn - -assert AIRFLOW_PATCHED - -if TYPE_CHECKING: - from airflow import DAG - from airflow.models import DagRun, TaskInstance - - from datahub.emitter.kafka_emitter import DatahubKafkaEmitter - from datahub.emitter.rest_emitter import DatahubRestEmitter - from datahub_provider._airflow_shims import Operator - - -def _task_downstream_task_ids(operator: "Operator") -> Set[str]: - if hasattr(operator, "downstream_task_ids"): - return operator.downstream_task_ids - return operator._downstream_task_id # type: ignore[attr-defined,union-attr] - - -class AirflowGenerator: - @staticmethod - def _get_dependencies( - task: "Operator", dag: "DAG", flow_urn: DataFlowUrn - ) -> List[DataJobUrn]: - from datahub_provider._airflow_shims import ExternalTaskSensor - - # resolve URNs for upstream nodes in subdags upstream of the current task. - upstream_subdag_task_urns: List[DataJobUrn] = [] - - for upstream_task_id in task.upstream_task_ids: - upstream_task = dag.task_dict[upstream_task_id] - - # if upstream task is not a subdag, then skip it - upstream_subdag = getattr(upstream_task, "subdag", None) - if upstream_subdag is None: - continue - - # else, link the leaf tasks of the upstream subdag as upstream tasks - for upstream_subdag_task_id in upstream_subdag.task_dict: - upstream_subdag_task = upstream_subdag.task_dict[ - upstream_subdag_task_id - ] - - upstream_subdag_task_urn = DataJobUrn.create_from_ids( - job_id=upstream_subdag_task_id, data_flow_urn=str(flow_urn) - ) - - # if subdag task is a leaf task, then link it as an upstream task - if len(_task_downstream_task_ids(upstream_subdag_task)) == 0: - upstream_subdag_task_urns.append(upstream_subdag_task_urn) - - # resolve URNs for upstream nodes that trigger the subdag containing the current task. 
- # (if it is in a subdag at all) - upstream_subdag_triggers: List[DataJobUrn] = [] - - # subdags are always named with 'parent.child' style or Airflow won't run them - # add connection from subdag trigger(s) if subdag task has no upstreams - if ( - dag.is_subdag - and dag.parent_dag is not None - and len(task.upstream_task_ids) == 0 - ): - # filter through the parent dag's tasks and find the subdag trigger(s) - subdags = [ - x for x in dag.parent_dag.task_dict.values() if x.subdag is not None - ] - matched_subdags = [ - x for x in subdags if x.subdag and x.subdag.dag_id == dag.dag_id - ] - - # id of the task containing the subdag - subdag_task_id = matched_subdags[0].task_id - - # iterate through the parent dag's tasks and find the ones that trigger the subdag - for upstream_task_id in dag.parent_dag.task_dict: - upstream_task = dag.parent_dag.task_dict[upstream_task_id] - upstream_task_urn = DataJobUrn.create_from_ids( - data_flow_urn=str(flow_urn), job_id=upstream_task_id - ) - - # if the task triggers the subdag, link it to this node in the subdag - if subdag_task_id in _task_downstream_task_ids(upstream_task): - upstream_subdag_triggers.append(upstream_task_urn) - - # If the operator is an ExternalTaskSensor then we set the remote task as upstream. - # It is possible to tie an external sensor to DAG if external_task_id is omitted but currently we can't tie - # jobflow to anothet jobflow. - external_task_upstreams = [] - if task.task_type == "ExternalTaskSensor": - task = cast(ExternalTaskSensor, task) - if hasattr(task, "external_task_id") and task.external_task_id is not None: - external_task_upstreams = [ - DataJobUrn.create_from_ids( - job_id=task.external_task_id, - data_flow_urn=str( - DataFlowUrn.create_from_ids( - orchestrator=flow_urn.get_orchestrator_name(), - flow_id=task.external_dag_id, - env=flow_urn.get_env(), - ) - ), - ) - ] - # exclude subdag operator tasks since these are not emitted, resulting in empty metadata - upstream_tasks = ( - [ - DataJobUrn.create_from_ids(job_id=task_id, data_flow_urn=str(flow_urn)) - for task_id in task.upstream_task_ids - if getattr(dag.task_dict[task_id], "subdag", None) is None - ] - + upstream_subdag_task_urns - + upstream_subdag_triggers - + external_task_upstreams - ) - return upstream_tasks - - @staticmethod - def generate_dataflow( - cluster: str, - dag: "DAG", - capture_owner: bool = True, - capture_tags: bool = True, - ) -> DataFlow: - """ - Generates a Dataflow object from an Airflow DAG - :param cluster: str - name of the cluster - :param dag: DAG - - :param capture_tags: - :param capture_owner: - :return: DataFlow - Data generated dataflow - """ - id = dag.dag_id - orchestrator = "airflow" - description = f"{dag.description}\n\n{dag.doc_md or ''}" - data_flow = DataFlow( - env=cluster, id=id, orchestrator=orchestrator, description=description - ) - - flow_property_bag: Dict[str, str] = {} - - allowed_flow_keys = [ - "_access_control", - "_concurrency", - "_default_view", - "catchup", - "fileloc", - "is_paused_upon_creation", - "start_date", - "tags", - "timezone", - ] - - for key in allowed_flow_keys: - if hasattr(dag, key): - flow_property_bag[key] = repr(getattr(dag, key)) - - data_flow.properties = flow_property_bag - base_url = conf.get("webserver", "base_url") - data_flow.url = f"{base_url}/tree?dag_id={dag.dag_id}" - - if capture_owner and dag.owner: - data_flow.owners.add(dag.owner) - - if capture_tags and dag.tags: - data_flow.tags.update(dag.tags) - - return data_flow - - @staticmethod - def _get_description(task: 
"Operator") -> Optional[str]: - from airflow.models.baseoperator import BaseOperator - - if not isinstance(task, BaseOperator): - # TODO: Get docs for mapped operators. - return None - - if hasattr(task, "doc") and task.doc: - return task.doc - elif hasattr(task, "doc_md") and task.doc_md: - return task.doc_md - elif hasattr(task, "doc_json") and task.doc_json: - return task.doc_json - elif hasattr(task, "doc_yaml") and task.doc_yaml: - return task.doc_yaml - elif hasattr(task, "doc_rst") and task.doc_yaml: - return task.doc_yaml - return None - - @staticmethod - def generate_datajob( - cluster: str, - task: "Operator", - dag: "DAG", - set_dependencies: bool = True, - capture_owner: bool = True, - capture_tags: bool = True, - ) -> DataJob: - """ - - :param cluster: str - :param task: TaskIntance - :param dag: DAG - :param set_dependencies: bool - whether to extract dependencies from airflow task - :param capture_owner: bool - whether to extract owner from airflow task - :param capture_tags: bool - whether to set tags automatically from airflow task - :return: DataJob - returns the generated DataJob object - """ - dataflow_urn = DataFlowUrn.create_from_ids( - orchestrator="airflow", env=cluster, flow_id=dag.dag_id - ) - datajob = DataJob(id=task.task_id, flow_urn=dataflow_urn) - - # TODO add support for MappedOperator - datajob.description = AirflowGenerator._get_description(task) - - job_property_bag: Dict[str, str] = {} - - allowed_task_keys = [ - "_downstream_task_ids", - "_inlets", - "_outlets", - "_task_type", - "_task_module", - "depends_on_past", - "email", - "label", - "execution_timeout", - "sla", - "sql", - "task_id", - "trigger_rule", - "wait_for_downstream", - # In Airflow 2.3, _downstream_task_ids was renamed to downstream_task_ids - "downstream_task_ids", - # In Airflow 2.4, _inlets and _outlets were removed in favor of non-private versions. 
- "inlets", - "outlets", - ] - - for key in allowed_task_keys: - if hasattr(task, key): - job_property_bag[key] = repr(getattr(task, key)) - - datajob.properties = job_property_bag - base_url = conf.get("webserver", "base_url") - datajob.url = f"{base_url}/taskinstance/list/?flt1_dag_id_equals={datajob.flow_urn.get_flow_id()}&_flt_3_task_id={task.task_id}" - - if capture_owner and dag.owner: - datajob.owners.add(dag.owner) - - if capture_tags and dag.tags: - datajob.tags.update(dag.tags) - - if set_dependencies: - datajob.upstream_urns.extend( - AirflowGenerator._get_dependencies( - task=task, dag=dag, flow_urn=datajob.flow_urn - ) - ) - - return datajob - - @staticmethod - def create_datajob_instance( - cluster: str, - task: "Operator", - dag: "DAG", - data_job: Optional[DataJob] = None, - ) -> DataProcessInstance: - if data_job is None: - data_job = AirflowGenerator.generate_datajob(cluster, task=task, dag=dag) - dpi = DataProcessInstance.from_datajob( - datajob=data_job, id=task.task_id, clone_inlets=True, clone_outlets=True - ) - return dpi - - @staticmethod - def run_dataflow( - emitter: Union["DatahubRestEmitter", "DatahubKafkaEmitter"], - cluster: str, - dag_run: "DagRun", - start_timestamp_millis: Optional[int] = None, - dataflow: Optional[DataFlow] = None, - ) -> None: - if dataflow is None: - assert dag_run.dag - dataflow = AirflowGenerator.generate_dataflow(cluster, dag_run.dag) - - if start_timestamp_millis is None: - assert dag_run.execution_date - start_timestamp_millis = int(dag_run.execution_date.timestamp() * 1000) - - assert dag_run.run_id - dpi = DataProcessInstance.from_dataflow(dataflow=dataflow, id=dag_run.run_id) - - # This property only exists in Airflow2 - if hasattr(dag_run, "run_type"): - from airflow.utils.types import DagRunType - - if dag_run.run_type == DagRunType.SCHEDULED: - dpi.type = DataProcessTypeClass.BATCH_SCHEDULED - elif dag_run.run_type == DagRunType.MANUAL: - dpi.type = DataProcessTypeClass.BATCH_AD_HOC - else: - if dag_run.run_id.startswith("scheduled__"): - dpi.type = DataProcessTypeClass.BATCH_SCHEDULED - else: - dpi.type = DataProcessTypeClass.BATCH_AD_HOC - - property_bag: Dict[str, str] = {} - property_bag["run_id"] = str(dag_run.run_id) - property_bag["execution_date"] = str(dag_run.execution_date) - property_bag["end_date"] = str(dag_run.end_date) - property_bag["start_date"] = str(dag_run.start_date) - property_bag["creating_job_id"] = str(dag_run.creating_job_id) - property_bag["data_interval_start"] = str(dag_run.data_interval_start) - property_bag["data_interval_end"] = str(dag_run.data_interval_end) - property_bag["external_trigger"] = str(dag_run.external_trigger) - dpi.properties.update(property_bag) - - dpi.emit_process_start( - emitter=emitter, start_timestamp_millis=start_timestamp_millis - ) - - @staticmethod - def complete_dataflow( - emitter: Union["DatahubRestEmitter", "DatahubKafkaEmitter"], - cluster: str, - dag_run: "DagRun", - end_timestamp_millis: Optional[int] = None, - dataflow: Optional[DataFlow] = None, - ) -> None: - """ - - :param emitter: DatahubRestEmitter - the datahub rest emitter to emit the generated mcps - :param cluster: str - name of the cluster - :param dag_run: DagRun - :param end_timestamp_millis: Optional[int] - the completion time in milliseconds if not set the current time will be used. 
- :param dataflow: Optional[Dataflow] - """ - if dataflow is None: - assert dag_run.dag - dataflow = AirflowGenerator.generate_dataflow(cluster, dag_run.dag) - - assert dag_run.run_id - dpi = DataProcessInstance.from_dataflow(dataflow=dataflow, id=dag_run.run_id) - if end_timestamp_millis is None: - if dag_run.end_date is None: - raise Exception( - f"Dag {dag_run.dag_id}_{dag_run.run_id} is still running and unable to get end_date..." - ) - end_timestamp_millis = int(dag_run.end_date.timestamp() * 1000) - - # We should use DagRunState but it is not available in Airflow 1 - if dag_run.state == "success": - result = InstanceRunResult.SUCCESS - elif dag_run.state == "failed": - result = InstanceRunResult.FAILURE - else: - raise Exception( - f"Result should be either success or failure and it was {dag_run.state}" - ) - - dpi.emit_process_end( - emitter=emitter, - end_timestamp_millis=end_timestamp_millis, - result=result, - result_type="airflow", - ) - - @staticmethod - def run_datajob( - emitter: Union["DatahubRestEmitter", "DatahubKafkaEmitter"], - cluster: str, - ti: "TaskInstance", - dag: "DAG", - dag_run: "DagRun", - start_timestamp_millis: Optional[int] = None, - datajob: Optional[DataJob] = None, - attempt: Optional[int] = None, - emit_templates: bool = True, - ) -> DataProcessInstance: - if datajob is None: - datajob = AirflowGenerator.generate_datajob(cluster, ti.task, dag) - - assert dag_run.run_id - dpi = DataProcessInstance.from_datajob( - datajob=datajob, - id=f"{dag.dag_id}_{ti.task_id}_{dag_run.run_id}", - clone_inlets=True, - clone_outlets=True, - ) - job_property_bag: Dict[str, str] = {} - job_property_bag["run_id"] = str(dag_run.run_id) - job_property_bag["duration"] = str(ti.duration) - job_property_bag["start_date"] = str(ti.start_date) - job_property_bag["end_date"] = str(ti.end_date) - job_property_bag["execution_date"] = str(ti.execution_date) - job_property_bag["try_number"] = str(ti.try_number - 1) - job_property_bag["hostname"] = str(ti.hostname) - job_property_bag["max_tries"] = str(ti.max_tries) - # Not compatible with Airflow 1 - if hasattr(ti, "external_executor_id"): - job_property_bag["external_executor_id"] = str(ti.external_executor_id) - job_property_bag["pid"] = str(ti.pid) - job_property_bag["state"] = str(ti.state) - job_property_bag["operator"] = str(ti.operator) - job_property_bag["priority_weight"] = str(ti.priority_weight) - job_property_bag["unixname"] = str(ti.unixname) - job_property_bag["log_url"] = ti.log_url - dpi.properties.update(job_property_bag) - dpi.url = ti.log_url - - # This property only exists in Airflow2 - if hasattr(ti, "dag_run") and hasattr(ti.dag_run, "run_type"): - from airflow.utils.types import DagRunType - - if ti.dag_run.run_type == DagRunType.SCHEDULED: - dpi.type = DataProcessTypeClass.BATCH_SCHEDULED - elif ti.dag_run.run_type == DagRunType.MANUAL: - dpi.type = DataProcessTypeClass.BATCH_AD_HOC - else: - if dag_run.run_id.startswith("scheduled__"): - dpi.type = DataProcessTypeClass.BATCH_SCHEDULED - else: - dpi.type = DataProcessTypeClass.BATCH_AD_HOC - - if start_timestamp_millis is None: - assert ti.start_date - start_timestamp_millis = int(ti.start_date.timestamp() * 1000) - - if attempt is None: - attempt = ti.try_number - - dpi.emit_process_start( - emitter=emitter, - start_timestamp_millis=start_timestamp_millis, - attempt=attempt, - emit_template=emit_templates, - ) - return dpi - - @staticmethod - def complete_datajob( - emitter: Union["DatahubRestEmitter", "DatahubKafkaEmitter"], - cluster: str, - ti: 
"TaskInstance", - dag: "DAG", - dag_run: "DagRun", - end_timestamp_millis: Optional[int] = None, - result: Optional[InstanceRunResult] = None, - datajob: Optional[DataJob] = None, - ) -> DataProcessInstance: - """ - - :param emitter: DatahubRestEmitter - :param cluster: str - :param ti: TaskInstance - :param dag: DAG - :param dag_run: DagRun - :param end_timestamp_millis: Optional[int] - :param result: Optional[str] One of the result from datahub.metadata.schema_class.RunResultTypeClass - :param datajob: Optional[DataJob] - :return: DataProcessInstance - """ - if datajob is None: - datajob = AirflowGenerator.generate_datajob(cluster, ti.task, dag) - - if end_timestamp_millis is None: - assert ti.end_date - end_timestamp_millis = int(ti.end_date.timestamp() * 1000) - - if result is None: - # We should use TaskInstanceState but it is not available in Airflow 1 - if ti.state == "success": - result = InstanceRunResult.SUCCESS - elif ti.state == "failed": - result = InstanceRunResult.FAILURE - else: - raise Exception( - f"Result should be either success or failure and it was {ti.state}" - ) - - dpi = DataProcessInstance.from_datajob( - datajob=datajob, - id=f"{dag.dag_id}_{ti.task_id}_{dag_run.run_id}", - clone_inlets=True, - clone_outlets=True, - ) - dpi.emit_process_end( - emitter=emitter, - end_timestamp_millis=end_timestamp_millis, - result=result, - result_type="airflow", - ) - return dpi +__all__ = ["AirflowGenerator"] diff --git a/metadata-ingestion/src/datahub_provider/entities.py b/metadata-ingestion/src/datahub_provider/entities.py index bfccc2f22eeb8..13be4ecdad655 100644 --- a/metadata-ingestion/src/datahub_provider/entities.py +++ b/metadata-ingestion/src/datahub_provider/entities.py @@ -1,48 +1,3 @@ -from abc import abstractmethod -from typing import Optional +from datahub_airflow_plugin.entities import Dataset, Urn, _Entity -import attr - -import datahub.emitter.mce_builder as builder -from datahub.utilities.urns.urn import guess_entity_type - - -class _Entity: - @property - @abstractmethod - def urn(self) -> str: - pass - - -@attr.s(auto_attribs=True, str=True) -class Dataset(_Entity): - platform: str - name: str - env: str = builder.DEFAULT_ENV - platform_instance: Optional[str] = None - - @property - def urn(self): - return builder.make_dataset_urn_with_platform_instance( - platform=self.platform, - name=self.name, - platform_instance=self.platform_instance, - env=self.env, - ) - - -@attr.s(str=True) -class Urn(_Entity): - _urn: str = attr.ib() - - @_urn.validator - def _validate_urn(self, attribute, value): - if not value.startswith("urn:"): - raise ValueError("invalid urn provided: urns must start with 'urn:'") - if guess_entity_type(value) != "dataset": - # This is because DataJobs only support Dataset lineage. 
- raise ValueError("Airflow lineage currently only supports datasets") - - @property - def urn(self): - return self._urn +__all__ = ["_Entity", "Dataset", "Urn"] diff --git a/metadata-ingestion/src/datahub_provider/hooks/datahub.py b/metadata-ingestion/src/datahub_provider/hooks/datahub.py index e2e523fc5d6af..949d98ce631ed 100644 --- a/metadata-ingestion/src/datahub_provider/hooks/datahub.py +++ b/metadata-ingestion/src/datahub_provider/hooks/datahub.py @@ -1,216 +1,8 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union - -from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook - -from datahub.metadata.com.linkedin.pegasus2avro.mxe import ( - MetadataChangeEvent, - MetadataChangeProposal, +from datahub_airflow_plugin.hooks.datahub import ( + BaseHook, + DatahubGenericHook, + DatahubKafkaHook, + DatahubRestHook, ) -if TYPE_CHECKING: - from airflow.models.connection import Connection - - from datahub.emitter.kafka_emitter import DatahubKafkaEmitter - from datahub.emitter.rest_emitter import DatahubRestEmitter - from datahub.ingestion.sink.datahub_kafka import KafkaSinkConfig - - -class DatahubRestHook(BaseHook): - """ - Creates a DataHub Rest API connection used to send metadata to DataHub. - Takes the endpoint for your DataHub Rest API in the Server Endpoint(host) field. - - URI example: :: - - AIRFLOW_CONN_DATAHUB_REST_DEFAULT='datahub-rest://rest-endpoint' - - :param datahub_rest_conn_id: Reference to the DataHub Rest connection. - :type datahub_rest_conn_id: str - """ - - conn_name_attr = "datahub_rest_conn_id" - default_conn_name = "datahub_rest_default" - conn_type = "datahub_rest" - hook_name = "DataHub REST Server" - - def __init__(self, datahub_rest_conn_id: str = default_conn_name) -> None: - super().__init__() - self.datahub_rest_conn_id = datahub_rest_conn_id - - @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: - return {} - - @staticmethod - def get_ui_field_behaviour() -> Dict: - """Returns custom field behavior""" - return { - "hidden_fields": ["port", "schema", "login"], - "relabeling": { - "host": "Server Endpoint", - }, - } - - def _get_config(self) -> Tuple[str, Optional[str], Optional[int]]: - conn: "Connection" = self.get_connection(self.datahub_rest_conn_id) - - host = conn.host - if not host: - raise AirflowException("host parameter is required") - if conn.port: - if ":" in host: - raise AirflowException( - "host parameter should not contain a port number if the port is specified separately" - ) - host = f"{host}:{conn.port}" - password = conn.password - timeout_sec = conn.extra_dejson.get("timeout_sec") - return (host, password, timeout_sec) - - def make_emitter(self) -> "DatahubRestEmitter": - import datahub.emitter.rest_emitter - - return datahub.emitter.rest_emitter.DatahubRestEmitter(*self._get_config()) - - def emit_mces(self, mces: List[MetadataChangeEvent]) -> None: - emitter = self.make_emitter() - - for mce in mces: - emitter.emit_mce(mce) - - def emit_mcps(self, mcps: List[MetadataChangeProposal]) -> None: - emitter = self.make_emitter() - - for mce in mcps: - emitter.emit_mcp(mce) - - -class DatahubKafkaHook(BaseHook): - """ - Creates a DataHub Kafka connection used to send metadata to DataHub. - Takes your kafka broker in the Kafka Broker(host) field. - - URI example: :: - - AIRFLOW_CONN_DATAHUB_KAFKA_DEFAULT='datahub-kafka://kafka-broker' - - :param datahub_kafka_conn_id: Reference to the DataHub Kafka connection. 
- :type datahub_kafka_conn_id: str - """ - - conn_name_attr = "datahub_kafka_conn_id" - default_conn_name = "datahub_kafka_default" - conn_type = "datahub_kafka" - hook_name = "DataHub Kafka Sink" - - def __init__(self, datahub_kafka_conn_id: str = default_conn_name) -> None: - super().__init__() - self.datahub_kafka_conn_id = datahub_kafka_conn_id - - @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: - return {} - - @staticmethod - def get_ui_field_behaviour() -> Dict: - """Returns custom field behavior""" - return { - "hidden_fields": ["port", "schema", "login", "password"], - "relabeling": { - "host": "Kafka Broker", - }, - } - - def _get_config(self) -> "KafkaSinkConfig": - import datahub.ingestion.sink.datahub_kafka - - conn = self.get_connection(self.datahub_kafka_conn_id) - obj = conn.extra_dejson - obj.setdefault("connection", {}) - if conn.host is not None: - if "bootstrap" in obj["connection"]: - raise AirflowException( - "Kafka broker specified twice (present in host and extra)" - ) - obj["connection"]["bootstrap"] = ":".join( - map(str, filter(None, [conn.host, conn.port])) - ) - config = datahub.ingestion.sink.datahub_kafka.KafkaSinkConfig.parse_obj(obj) - return config - - def make_emitter(self) -> "DatahubKafkaEmitter": - import datahub.emitter.kafka_emitter - - sink_config = self._get_config() - return datahub.emitter.kafka_emitter.DatahubKafkaEmitter(sink_config) - - def emit_mces(self, mces: List[MetadataChangeEvent]) -> None: - emitter = self.make_emitter() - errors = [] - - def callback(exc, msg): - if exc: - errors.append(exc) - - for mce in mces: - emitter.emit_mce_async(mce, callback) - - emitter.flush() - - if errors: - raise AirflowException(f"failed to push some MCEs: {errors}") - - def emit_mcps(self, mcps: List[MetadataChangeProposal]) -> None: - emitter = self.make_emitter() - errors = [] - - def callback(exc, msg): - if exc: - errors.append(exc) - - for mcp in mcps: - emitter.emit_mcp_async(mcp, callback) - - emitter.flush() - - if errors: - raise AirflowException(f"failed to push some MCPs: {errors}") - - -class DatahubGenericHook(BaseHook): - """ - Emits Metadata Change Events using either the DatahubRestHook or the - DatahubKafkaHook. Set up a DataHub Rest or Kafka connection to use. - - :param datahub_conn_id: Reference to the DataHub connection. - :type datahub_conn_id: str - """ - - def __init__(self, datahub_conn_id: str) -> None: - super().__init__() - self.datahub_conn_id = datahub_conn_id - - def get_underlying_hook(self) -> Union[DatahubRestHook, DatahubKafkaHook]: - conn = self.get_connection(self.datahub_conn_id) - - # We need to figure out the underlying hook type. First check the - # conn_type. If that fails, attempt to guess using the conn id name. 
- if conn.conn_type == DatahubRestHook.conn_type: - return DatahubRestHook(self.datahub_conn_id) - elif conn.conn_type == DatahubKafkaHook.conn_type: - return DatahubKafkaHook(self.datahub_conn_id) - elif "rest" in self.datahub_conn_id: - return DatahubRestHook(self.datahub_conn_id) - elif "kafka" in self.datahub_conn_id: - return DatahubKafkaHook(self.datahub_conn_id) - else: - raise AirflowException( - f"DataHub cannot handle conn_type {conn.conn_type} in {conn}" - ) - - def make_emitter(self) -> Union["DatahubRestEmitter", "DatahubKafkaEmitter"]: - return self.get_underlying_hook().make_emitter() - - def emit_mces(self, mces: List[MetadataChangeEvent]) -> None: - return self.get_underlying_hook().emit_mces(mces) +__all__ = ["DatahubRestHook", "DatahubKafkaHook", "DatahubGenericHook", "BaseHook"] diff --git a/metadata-ingestion/src/datahub_provider/lineage/datahub.py b/metadata-ingestion/src/datahub_provider/lineage/datahub.py index 009ce4bb29a97..ffe1adb8255b2 100644 --- a/metadata-ingestion/src/datahub_provider/lineage/datahub.py +++ b/metadata-ingestion/src/datahub_provider/lineage/datahub.py @@ -1,91 +1,6 @@ -import json -from typing import TYPE_CHECKING, Dict, List, Optional - -from airflow.configuration import conf -from airflow.lineage.backend import LineageBackend - -from datahub_provider._lineage_core import ( - DatahubBasicLineageConfig, - send_lineage_to_datahub, +from datahub_airflow_plugin.lineage.datahub import ( + DatahubLineageBackend, + DatahubLineageConfig, ) -if TYPE_CHECKING: - from airflow.models.baseoperator import BaseOperator - - -class DatahubLineageConfig(DatahubBasicLineageConfig): - # If set to true, most runtime errors in the lineage backend will be - # suppressed and will not cause the overall task to fail. Note that - # configuration issues will still throw exceptions. - graceful_exceptions: bool = True - - -def get_lineage_config() -> DatahubLineageConfig: - """Load the lineage config from airflow.cfg.""" - - # The kwargs pattern is also used for secret backends. - kwargs_str = conf.get("lineage", "datahub_kwargs", fallback="{}") - kwargs = json.loads(kwargs_str) - - # Continue to support top-level datahub_conn_id config. - datahub_conn_id = conf.get("lineage", "datahub_conn_id", fallback=None) - if datahub_conn_id: - kwargs["datahub_conn_id"] = datahub_conn_id - - return DatahubLineageConfig.parse_obj(kwargs) - - -class DatahubLineageBackend(LineageBackend): - """ - Sends lineage data from tasks to DataHub. - - Configurable via ``airflow.cfg`` as follows: :: - - # For REST-based: - airflow connections add --conn-type 'datahub_rest' 'datahub_rest_default' --conn-host 'http://localhost:8080' - # For Kafka-based (standard Kafka sink config can be passed via extras): - airflow connections add --conn-type 'datahub_kafka' 'datahub_kafka_default' --conn-host 'broker:9092' --conn-extra '{}' - - [lineage] - backend = datahub_provider.lineage.datahub.DatahubLineageBackend - datahub_kwargs = { - "datahub_conn_id": "datahub_rest_default", - "capture_ownership_info": true, - "capture_tags_info": true, - "graceful_exceptions": true } - # The above indentation is important! - """ - - def __init__(self) -> None: - super().__init__() - - # By attempting to get and parse the config, we can detect configuration errors - # ahead of time. The init method is only called in Airflow 2.x. - _ = get_lineage_config() - - # With Airflow 2.0, this can be an instance method. 
However, with Airflow 1.10.x, this - # method is used statically, even though LineageBackend declares it as an instance variable. - @staticmethod - def send_lineage( - operator: "BaseOperator", - inlets: Optional[List] = None, # unused - outlets: Optional[List] = None, # unused - context: Optional[Dict] = None, - ) -> None: - config = get_lineage_config() - if not config.enabled: - return - - try: - context = context or {} # ensure not None to satisfy mypy - send_lineage_to_datahub( - config, operator, operator.inlets, operator.outlets, context - ) - except Exception as e: - if config.graceful_exceptions: - operator.log.error(e) - operator.log.info( - "Suppressing error because graceful_exceptions is set" - ) - else: - raise +__all__ = ["DatahubLineageBackend", "DatahubLineageConfig"] diff --git a/metadata-ingestion/src/datahub_provider/operators/datahub.py b/metadata-ingestion/src/datahub_provider/operators/datahub.py index cd1d5187e6d85..08b1807cd4614 100644 --- a/metadata-ingestion/src/datahub_provider/operators/datahub.py +++ b/metadata-ingestion/src/datahub_provider/operators/datahub.py @@ -1,63 +1,6 @@ -from typing import List, Union - -from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults - -from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent -from datahub_provider.hooks.datahub import ( - DatahubGenericHook, - DatahubKafkaHook, - DatahubRestHook, +from datahub_airflow_plugin.operators.datahub import ( + DatahubBaseOperator, + DatahubEmitterOperator, ) - -class DatahubBaseOperator(BaseOperator): - """ - The DatahubBaseOperator is used as a base operator all DataHub operators. - """ - - ui_color = "#4398c8" - - hook: Union[DatahubRestHook, DatahubKafkaHook] - - # mypy is not a fan of this. Newer versions of Airflow support proper typing for the decorator - # using PEP 612. However, there is not yet a good way to inherit the types of the kwargs from - # the superclass. - @apply_defaults # type: ignore[misc] - def __init__( # type: ignore[no-untyped-def] - self, - *, - datahub_conn_id: str, - **kwargs, - ): - super().__init__(**kwargs) - - self.datahub_conn_id = datahub_conn_id - self.generic_hook = DatahubGenericHook(datahub_conn_id) - - -class DatahubEmitterOperator(DatahubBaseOperator): - """ - Emits a Metadata Change Event to DataHub using either a DataHub - Rest or Kafka connection. - - :param datahub_conn_id: Reference to the DataHub Rest or Kafka Connection. - :type datahub_conn_id: str - """ - - # See above for why these mypy type issues are ignored here. 
- @apply_defaults # type: ignore[misc] - def __init__( # type: ignore[no-untyped-def] - self, - mces: List[MetadataChangeEvent], - datahub_conn_id: str, - **kwargs, - ): - super().__init__( - datahub_conn_id=datahub_conn_id, - **kwargs, - ) - self.mces = mces - - def execute(self, context): - self.generic_hook.get_underlying_hook().emit_mces(self.mces) +__all__ = ["DatahubEmitterOperator", "DatahubBaseOperator"] diff --git a/metadata-ingestion/src/datahub_provider/operators/datahub_assertion_operator.py b/metadata-ingestion/src/datahub_provider/operators/datahub_assertion_operator.py index 28be8ad860179..85469c10f271c 100644 --- a/metadata-ingestion/src/datahub_provider/operators/datahub_assertion_operator.py +++ b/metadata-ingestion/src/datahub_provider/operators/datahub_assertion_operator.py @@ -1,78 +1,5 @@ -import datetime -from typing import Any, List, Optional, Sequence, Union - -from airflow.models import BaseOperator - -from datahub.api.circuit_breaker import ( - AssertionCircuitBreaker, - AssertionCircuitBreakerConfig, +from datahub_airflow_plugin.operators.datahub_assertion_operator import ( + DataHubAssertionOperator, ) -from datahub_provider.hooks.datahub import DatahubRestHook - - -class DataHubAssertionOperator(BaseOperator): - r""" - DataHub Assertion Circuit Breaker Operator. - - :param urn: The DataHub dataset unique identifier. (templated) - :param datahub_rest_conn_id: The REST datahub connection id to communicate with DataHub - which is set as Airflow connection. - :param check_last_assertion_time: If set it checks assertions after the last operation was set on the dataset. - By default it is True. - :param time_delta: If verify_after_last_update is False it checks for assertion within the time delta. - """ - - template_fields: Sequence[str] = ("urn",) - circuit_breaker: AssertionCircuitBreaker - urn: Union[List[str], str] - - def __init__( # type: ignore[no-untyped-def] - self, - *, - urn: Union[List[str], str], - datahub_rest_conn_id: Optional[str] = None, - check_last_assertion_time: bool = True, - time_delta: Optional[datetime.timedelta] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - hook: DatahubRestHook - if datahub_rest_conn_id is not None: - hook = DatahubRestHook(datahub_rest_conn_id=datahub_rest_conn_id) - else: - hook = DatahubRestHook() - - host, password, timeout_sec = hook._get_config() - self.urn = urn - config: AssertionCircuitBreakerConfig = AssertionCircuitBreakerConfig( - datahub_host=host, - datahub_token=password, - timeout=timeout_sec, - verify_after_last_update=check_last_assertion_time, - time_delta=time_delta if time_delta else datetime.timedelta(days=1), - ) - - self.circuit_breaker = AssertionCircuitBreaker(config=config) - - def execute(self, context: Any) -> bool: - if "datahub_silence_circuit_breakers" in context["dag_run"].conf: - self.log.info( - "Circuit breaker is silenced because datahub_silence_circuit_breakers config is set" - ) - return True - - self.log.info(f"Checking if dataset {self.urn} is ready to be consumed") - if isinstance(self.urn, str): - urns = [self.urn] - elif isinstance(self.urn, list): - urns = self.urn - else: - raise Exception(f"urn parameter has invalid type {type(self.urn)}") - - for urn in urns: - self.log.info(f"Checking if dataset {self.urn} is ready to be consumed") - ret = self.circuit_breaker.is_circuit_breaker_active(urn=urn) - if ret: - raise Exception(f"Dataset {self.urn} is not in consumable state") - return True +__all__ = ["DataHubAssertionOperator"] diff --git 
a/metadata-ingestion/src/datahub_provider/operators/datahub_assertion_sensor.py b/metadata-ingestion/src/datahub_provider/operators/datahub_assertion_sensor.py index ceb970dd8dc7f..e560ecb6145e0 100644 --- a/metadata-ingestion/src/datahub_provider/operators/datahub_assertion_sensor.py +++ b/metadata-ingestion/src/datahub_provider/operators/datahub_assertion_sensor.py @@ -1,78 +1,5 @@ -import datetime -from typing import Any, List, Optional, Sequence, Union - -from airflow.sensors.base import BaseSensorOperator - -from datahub.api.circuit_breaker import ( - AssertionCircuitBreaker, - AssertionCircuitBreakerConfig, +from datahub_airflow_plugin.operators.datahub_assertion_sensor import ( + DataHubAssertionSensor, ) -from datahub_provider.hooks.datahub import DatahubRestHook - - -class DataHubAssertionSensor(BaseSensorOperator): - r""" - DataHub Assertion Circuit Breaker Sensor. - - :param urn: The DataHub dataset unique identifier. (templated) - :param datahub_rest_conn_id: The REST datahub connection id to communicate with DataHub - which is set as Airflow connection. - :param check_last_assertion_time: If set it checks assertions after the last operation was set on the dataset. - By default it is True. - :param time_delta: If verify_after_last_update is False it checks for assertion within the time delta. - """ - - template_fields: Sequence[str] = ("urn",) - circuit_breaker: AssertionCircuitBreaker - urn: Union[List[str], str] - - def __init__( # type: ignore[no-untyped-def] - self, - *, - urn: Union[List[str], str], - datahub_rest_conn_id: Optional[str] = None, - check_last_assertion_time: bool = True, - time_delta: datetime.timedelta = datetime.timedelta(days=1), - **kwargs, - ) -> None: - super().__init__(**kwargs) - hook: DatahubRestHook - if datahub_rest_conn_id is not None: - hook = DatahubRestHook(datahub_rest_conn_id=datahub_rest_conn_id) - else: - hook = DatahubRestHook() - - host, password, timeout_sec = hook._get_config() - self.urn = urn - config: AssertionCircuitBreakerConfig = AssertionCircuitBreakerConfig( - datahub_host=host, - datahub_token=password, - timeout=timeout_sec, - verify_after_last_update=check_last_assertion_time, - time_delta=time_delta, - ) - self.circuit_breaker = AssertionCircuitBreaker(config=config) - - def poke(self, context: Any) -> bool: - if "datahub_silence_circuit_breakers" in context["dag_run"].conf: - self.log.info( - "Circuit breaker is silenced because datahub_silence_circuit_breakers config is set" - ) - return True - - self.log.info(f"Checking if dataset {self.urn} is ready to be consumed") - if isinstance(self.urn, str): - urns = [self.urn] - elif isinstance(self.urn, list): - urns = self.urn - else: - raise Exception(f"urn parameter has invalid type {type(self.urn)}") - - for urn in urns: - self.log.info(f"Checking if dataset {self.urn} is ready to be consumed") - ret = self.circuit_breaker.is_circuit_breaker_active(urn=urn) - if ret: - self.log.info(f"Dataset {self.urn} is not in consumable state") - return False - return True +__all__ = ["DataHubAssertionSensor"] diff --git a/metadata-ingestion/src/datahub_provider/operators/datahub_operation_operator.py b/metadata-ingestion/src/datahub_provider/operators/datahub_operation_operator.py index 6b2535994c101..6107e70c9eddd 100644 --- a/metadata-ingestion/src/datahub_provider/operators/datahub_operation_operator.py +++ b/metadata-ingestion/src/datahub_provider/operators/datahub_operation_operator.py @@ -1,97 +1,5 @@ -import datetime -from typing import Any, List, Optional, Sequence, Union - 
-from airflow.sensors.base import BaseSensorOperator - -from datahub.api.circuit_breaker import ( - OperationCircuitBreaker, - OperationCircuitBreakerConfig, +from datahub_airflow_plugin.operators.datahub_operation_operator import ( + DataHubOperationCircuitBreakerOperator, ) -from datahub_provider.hooks.datahub import DatahubRestHook - - -class DataHubOperationCircuitBreakerOperator(BaseSensorOperator): - r""" - DataHub Operation Circuit Breaker Operator. - - :param urn: The DataHub dataset unique identifier. (templated) - :param datahub_rest_conn_id: The REST datahub connection id to communicate with DataHub - which is set as Airflow connection. - :param partition: The partition to check the operation. - :param source_type: The partition to check the operation. :ref:`https://datahubproject.io/docs/graphql/enums#operationsourcetype` - - """ - - template_fields: Sequence[str] = ( - "urn", - "partition", - "source_type", - "operation_type", - ) - circuit_breaker: OperationCircuitBreaker - urn: Union[List[str], str] - partition: Optional[str] - source_type: Optional[str] - operation_type: Optional[str] - - def __init__( # type: ignore[no-untyped-def] - self, - *, - urn: Union[List[str], str], - datahub_rest_conn_id: Optional[str] = None, - time_delta: Optional[datetime.timedelta] = datetime.timedelta(days=1), - partition: Optional[str] = None, - source_type: Optional[str] = None, - operation_type: Optional[str] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - hook: DatahubRestHook - if datahub_rest_conn_id is not None: - hook = DatahubRestHook(datahub_rest_conn_id=datahub_rest_conn_id) - else: - hook = DatahubRestHook() - - host, password, timeout_sec = hook._get_config() - - self.urn = urn - self.partition = partition - self.operation_type = operation_type - self.source_type = source_type - - config: OperationCircuitBreakerConfig = OperationCircuitBreakerConfig( - datahub_host=host, - datahub_token=password, - timeout=timeout_sec, - time_delta=time_delta, - ) - - self.circuit_breaker = OperationCircuitBreaker(config=config) - - def execute(self, context: Any) -> bool: - if "datahub_silence_circuit_breakers" in context["dag_run"].conf: - self.log.info( - "Circuit breaker is silenced because datahub_silence_circuit_breakers config is set" - ) - return True - - self.log.info(f"Checking if dataset {self.urn} is ready to be consumed") - if isinstance(self.urn, str): - urns = [self.urn] - elif isinstance(self.urn, list): - urns = self.urn - else: - raise Exception(f"urn parameter has invalid type {type(self.urn)}") - - for urn in urns: - self.log.info(f"Checking if dataset {self.urn} is ready to be consumed") - ret = self.circuit_breaker.is_circuit_breaker_active( - urn=urn, - partition=self.partition, - operation_type=self.operation_type, - source_type=self.source_type, - ) - if ret: - raise Exception(f"Dataset {self.urn} is not in consumable state") - return True +__all__ = ["DataHubOperationCircuitBreakerOperator"] diff --git a/metadata-ingestion/src/datahub_provider/operators/datahub_operation_sensor.py b/metadata-ingestion/src/datahub_provider/operators/datahub_operation_sensor.py index 8796215453500..902a342081490 100644 --- a/metadata-ingestion/src/datahub_provider/operators/datahub_operation_sensor.py +++ b/metadata-ingestion/src/datahub_provider/operators/datahub_operation_sensor.py @@ -1,100 +1,5 @@ -import datetime -from typing import Any, List, Optional, Sequence, Union - -from airflow.sensors.base import BaseSensorOperator - -from datahub.api.circuit_breaker import ( - 
OperationCircuitBreaker, - OperationCircuitBreakerConfig, +from datahub_airflow_plugin.operators.datahub_operation_sensor import ( + DataHubOperationCircuitBreakerSensor, ) -from datahub_provider.hooks.datahub import DatahubRestHook - - -class DataHubOperationCircuitBreakerSensor(BaseSensorOperator): - r""" - DataHub Operation Circuit Breaker Sensor. - - :param urn: The DataHub dataset unique identifier. (templated) - :param datahub_rest_conn_id: The REST datahub connection id to communicate with DataHub - which is set as Airflow connection. - :param partition: The partition to check the operation. - :param source_type: The source type to filter on. If not set it will accept any source type. - See valid values at: https://datahubproject.io/docs/graphql/enums#operationsourcetype - :param operation_type: The operation type to filter on. If not set it will accept any source type. - See valid values at: https://datahubproject.io/docs/graphql/enums/#operationtype - """ - - template_fields: Sequence[str] = ( - "urn", - "partition", - "source_type", - "operation_type", - ) - circuit_breaker: OperationCircuitBreaker - urn: Union[List[str], str] - partition: Optional[str] - source_type: Optional[str] - operation_type: Optional[str] - - def __init__( # type: ignore[no-untyped-def] - self, - *, - urn: Union[List[str], str], - datahub_rest_conn_id: Optional[str] = None, - time_delta: Optional[datetime.timedelta] = datetime.timedelta(days=1), - partition: Optional[str] = None, - source_type: Optional[str] = None, - operation_type: Optional[str] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - hook: DatahubRestHook - if datahub_rest_conn_id is not None: - hook = DatahubRestHook(datahub_rest_conn_id=datahub_rest_conn_id) - else: - hook = DatahubRestHook() - - host, password, timeout_sec = hook._get_config() - - self.urn = urn - self.partition = partition - self.operation_type = operation_type - self.source_type = source_type - - config: OperationCircuitBreakerConfig = OperationCircuitBreakerConfig( - datahub_host=host, - datahub_token=password, - timeout=timeout_sec, - time_delta=time_delta, - ) - - self.circuit_breaker = OperationCircuitBreaker(config=config) - - def poke(self, context: Any) -> bool: - if "datahub_silence_circuit_breakers" in context["dag_run"].conf: - self.log.info( - "Circuit breaker is silenced because datahub_silence_circuit_breakers config is set" - ) - return True - - self.log.info(f"Checking if dataset {self.urn} is ready to be consumed") - if isinstance(self.urn, str): - urns = [self.urn] - elif isinstance(self.urn, list): - urns = self.urn - else: - raise Exception(f"urn parameter has invalid type {type(self.urn)}") - - for urn in urns: - self.log.info(f"Checking if dataset {self.urn} is ready to be consumed") - ret = self.circuit_breaker.is_circuit_breaker_active( - urn=urn, - partition=self.partition, - operation_type=self.operation_type, - source_type=self.source_type, - ) - if ret: - self.log.info(f"Dataset {self.urn} is not in consumable state") - return False - return True +__all__ = ["DataHubOperationCircuitBreakerSensor"]
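
The net effect of the shim files introduced above is backwards compatibility: each legacy `datahub_provider` module now only re-exports the corresponding class from `datahub_airflow_plugin`, so DAGs written against the old import paths keep working after the package split. A minimal sketch of the two equivalent import styles, assuming the core `acryl-datahub` package and the relocated Airflow plugin distribution are both installed; the dataset platform, names, and URN below are illustrative placeholders only:

    # Legacy path: still resolves, routed through the re-export shim.
    from datahub_provider.entities import Dataset, Urn
    from datahub_provider.operators.datahub import DatahubEmitterOperator

    # Preferred path going forward: import from the relocated plugin package.
    from datahub_airflow_plugin.entities import Dataset as PluginDataset

    # The shim re-exports the same class object, so both names are identical.
    assert Dataset is PluginDataset

    # Illustrative lineage annotations for an Airflow task (placeholder names).
    inlets = [Dataset("snowflake", "analytics.raw.events")]
    outlets = [
        Urn("urn:li:dataset:(urn:li:dataPlatform:snowflake,analytics.clean.events,PROD)")
    ]

Because the shims forward to the same implementation, existing `airflow.cfg` settings such as `backend = datahub_provider.lineage.datahub.DatahubLineageBackend` continue to load; new configurations can reference the `datahub_airflow_plugin` module path directly.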