From fce3a583348162f655282d032eca654dcb67b497 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Fri, 29 Mar 2024 10:06:02 +0800 Subject: [PATCH] Implement context accessor for DatasetEvent extra (#38481) --- .pre-commit-config.yaml | 5 + airflow/datasets/__init__.py | 11 +- airflow/models/taskinstance.py | 14 ++- airflow/utils/context.py | 34 ++++++ airflow/utils/context.pyi | 12 +- contributing-docs/08_static_code_checks.rst | 2 + .../doc/images/output_static-checks.svg | 114 +++++++++--------- .../doc/images/output_static-checks.txt | 2 +- .../src/airflow_breeze/pre_commit_ids.py | 1 + .../authoring-and-scheduling/datasets.rst | 31 ++++- docs/apache-airflow/templates-ref.rst | 2 + .../pre_commit_template_context_key_sync.py | 0 tests/models/test_taskinstance.py | 67 +++++++++- tests/operators/test_python.py | 1 + 14 files changed, 228 insertions(+), 68 deletions(-) mode change 100644 => 100755 scripts/ci/pre_commit/pre_commit_template_context_key_sync.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1829f9b200492..2f347b1c8851c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -629,6 +629,11 @@ repos: entry: ./scripts/ci/pre_commit/pre_commit_sync_init_decorator.py pass_filenames: false files: ^airflow/models/dag\.py$|^airflow/(?:decorators|utils)/task_group\.py$ + - id: check-template-context-variable-in-sync + name: Check all template context variable references are in sync + language: python + entry: ./scripts/ci/pre_commit/pre_commit_template_context_key_sync.py + files: ^airflow/models/taskinstance\.py$|^airflow/utils/context\.pyi?$|^docs/apache-airflow/templates-ref\.rst$ - id: check-base-operator-usage language: pygrep name: Check BaseOperator core imports diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py index 2507c69d01b43..d20d3b578e508 100644 --- a/airflow/datasets/__init__.py +++ b/airflow/datasets/__init__.py @@ -42,7 +42,14 @@ def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | N return ProvidersManager().dataset_uri_handlers.get(scheme) -def _sanitize_uri(uri: str) -> str: +def sanitize_uri(uri: str) -> str: + """Sanitize a dataset URI. + + This checks for URI validity, and normalizes the URI if needed. A fully + normalized URI is returned. + + :meta private: + """ if not uri: raise ValueError("Dataset URI cannot be empty") if uri.isspace(): @@ -110,7 +117,7 @@ class Dataset(os.PathLike, BaseDatasetEventInput): """A representation of data dependencies between workflows.""" uri: str = attr.field( - converter=_sanitize_uri, + converter=sanitize_uri, validator=[attr.validators.min_len(1), attr.validators.max_len(3000)], ) extra: dict[str, Any] | None = None diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 8bb9947327d3a..c9bd2ce617154 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -104,7 +104,13 @@ from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS from airflow.utils import timezone -from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor, context_merge +from airflow.utils.context import ( + ConnectionAccessor, + Context, + DatasetEventAccessors, + VariableAccessor, + context_merge, +) from airflow.utils.email import send_email from airflow.utils.helpers import prune_dict, render_template_to_string from airflow.utils.log.logging_mixin import LoggingMixin @@ -766,6 +772,7 @@ def get_triggering_events() -> dict[str, list[DatasetEvent | DatasetEventPydanti "dag_run": dag_run, "data_interval_end": timezone.coerce_datetime(data_interval.end), "data_interval_start": timezone.coerce_datetime(data_interval.start), + "dataset_events": DatasetEventAccessors(), "ds": ds, "ds_nodash": ds_nodash, "execution_date": logical_date, @@ -2569,7 +2576,7 @@ def _run_raw_task( session.add(Log(self.state, self)) session.merge(self).task = self.task if self.state == TaskInstanceState.SUCCESS: - self._register_dataset_changes(session=session) + self._register_dataset_changes(events=context["dataset_events"], session=session) session.commit() if self.state == TaskInstanceState.SUCCESS: @@ -2579,7 +2586,7 @@ def _run_raw_task( return None - def _register_dataset_changes(self, *, session: Session) -> None: + def _register_dataset_changes(self, *, events: DatasetEventAccessors, session: Session) -> None: if TYPE_CHECKING: assert self.task @@ -2590,6 +2597,7 @@ def _register_dataset_changes(self, *, session: Session) -> None: dataset_manager.register_dataset_change( task_instance=self, dataset=obj, + extra=events[obj].extra, session=session, ) diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 3501ca7dbc22a..033b7aa39d3ba 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -36,8 +36,10 @@ ValuesView, ) +import attrs import lazy_object_proxy +from airflow.datasets import Dataset, sanitize_uri from airflow.exceptions import RemovedInAirflow3Warning from airflow.utils.types import NOTSET @@ -54,6 +56,7 @@ "dag_run", "data_interval_end", "data_interval_start", + "dataset_events", "ds", "ds_nodash", "execution_date", @@ -146,6 +149,37 @@ def get(self, key: str, default_conn: Any = None) -> Any: return default_conn +@attrs.define() +class DatasetEventAccessor: + """Wrapper to access a DatasetEvent instance in template.""" + + extra: dict[str, Any] + + +class DatasetEventAccessors(Mapping[str, DatasetEventAccessor]): + """Lazy mapping of dataset event accessors.""" + + def __init__(self) -> None: + self._dict: dict[str, DatasetEventAccessor] = {} + + def __iter__(self) -> Iterator[str]: + return iter(self._dict) + + def __len__(self) -> int: + return len(self._dict) + + def __getitem__(self, key: str | Dataset) -> DatasetEventAccessor: + if isinstance(key, str): + uri = sanitize_uri(key) + elif isinstance(key, Dataset): + uri = key.uri + else: + return NotImplemented + if uri not in self._dict: + self._dict[uri] = DatasetEventAccessor({}) + return self._dict[uri] + + class AirflowContextDeprecationWarning(RemovedInAirflow3Warning): """Warn for usage of deprecated context variables in a task.""" diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi index eb08201248173..8b5deb4746918 100644 --- a/airflow/utils/context.pyi +++ b/airflow/utils/context.pyi @@ -26,11 +26,12 @@ # declare "these are defined, but don't error if others are accessed" someday. from __future__ import annotations -from typing import Any, Collection, Container, Iterable, Mapping, overload +from typing import Any, Collection, Container, Iterable, Iterator, Mapping, overload from pendulum import DateTime from airflow.configuration import AirflowConfigParser +from airflow.datasets import Dataset from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG from airflow.models.dagrun import DagRun @@ -55,6 +56,14 @@ class VariableAccessor: class ConnectionAccessor: def get(self, key: str, default_conn: Any = None) -> Any: ... +class DatasetEventAccessor: + extra: dict[str, Any] + +class DatasetEventAccessors(Mapping[str, DatasetEventAccessor]): + def __iter__(self) -> Iterator[str]: ... + def __len__(self) -> int: ... + def __getitem__(self, key: str | Dataset) -> DatasetEventAccessor: ... + # NOTE: Please keep this in sync with the following: # * KNOWN_CONTEXT_KEYS in airflow/utils/context.py # * Table in docs/apache-airflow/templates-ref.rst @@ -65,6 +74,7 @@ class Context(TypedDict, total=False): dag_run: DagRun | DagRunPydantic data_interval_end: DateTime data_interval_start: DateTime + dataset_events: DatasetEventAccessors ds: str ds_nodash: str exception: BaseException | str | None diff --git a/contributing-docs/08_static_code_checks.rst b/contributing-docs/08_static_code_checks.rst index 0b331bf3e95b4..c7be51b6a78fe 100644 --- a/contributing-docs/08_static_code_checks.rst +++ b/contributing-docs/08_static_code_checks.rst @@ -222,6 +222,8 @@ require Breeze Docker image to be built locally. +-----------------------------------------------------------+--------------------------------------------------------------+---------+ | check-system-tests-tocs | Check that system tests is properly added | | +-----------------------------------------------------------+--------------------------------------------------------------+---------+ +| check-template-context-variable-in-sync | Check all template context variable references are in sync | | ++-----------------------------------------------------------+--------------------------------------------------------------+---------+ | check-tests-in-the-right-folders | Check if tests are in the right folders | | +-----------------------------------------------------------+--------------------------------------------------------------+---------+ | check-tests-unittest-testcase | Check that unit tests do not inherit from unittest.TestCase | | diff --git a/dev/breeze/doc/images/output_static-checks.svg b/dev/breeze/doc/images/output_static-checks.svg index 679db3dfeec89..a709f8070c9d9 100644 --- a/dev/breeze/doc/images/output_static-checks.svg +++ b/dev/breeze/doc/images/output_static-checks.svg @@ -1,4 +1,4 @@ - +