From c3d87b89fb3160f2d7e4fb1c26f4b2249cff4013 Mon Sep 17 00:00:00 2001 From: Quigley Malcolm Date: Wed, 27 Nov 2024 16:06:41 -0600 Subject: [PATCH] Add `batch` context object to microbatch jinja context (#11031) * Add `batch_id` to jinja context of microbatch batches * Add changie doc * Update `format_batch_start` to assume `batch_start` is always provided * Add "runtime only" property `batch_context` to `ModelNode` By it being "runtime only" we mean that it doesn't exist on the artifact and thus won't be written out to the manifest artifact. * Begin populating `batch_context` during materialization execution for microbatch batches * Fix circular import * Fixup MicrobatchBuilder.batch_id property method * Ensure MicrobatchModelRunner doesn't double compile batches We were compiling the node for each batch _twice_. Besides making microbatch models more expensive than they needed to be, double compiling wasn't causing any issue. However the first compilation was happening _before_ we had added the batch context information to the model node for the batch. This was leading to models which try to access the `batch_context` information on the model to blow up, which was undesirable. As such, we've now gone and skipped the first compilation. We've done this similar to how SavedQuery nodes skip compilation. * Add `__post_serialize__` method to `BatchContext` to ensure correct dict shape This is weird, but necessary, I apologize. Mashumaro handles the dictification of this class via a compile time generated `to_dict` method based off of the _typing_ of th class. By default `datetime` types are converted to strings. We don't want that, we want them to stay datetimes. * Update tests to check for `batch_context` * Update `resolve_event_time_filter` to use new `batch_context` * Stop testing for batchless compiled code for microbatch models In 45daec72f4953be4c157e6e2ab5671455a569396 we stopped an extra compilation that was happening per batch prior to the batch_context being loaded. Stopping this extra compilation means that compiled sql for the microbatch model without the event time filter / batch context is no longer produced. We have discussed this and _believe_ it is okay given that this is a new node type that has not hit GA yet. * Rename `ModelNode.batch_context` to `ModelNode.batch` * Rename `build_batch_context` to `build_jinja_context_for_batch` The name `build_batch_context` was confusing as 1) We have a `BatchContext` object, which the method was not building 2) The method builds the jinja context for the batch As such it felt appropriate to rename the method to more accurately communicate what it does. * Rename test macro `invalid_batch_context_macro_sql` to `invalid_batch_jinja_context_macro_sql` This rename was to make it more clear that the jinja context for a batch was being checked, as a batch_context has a slightly different connotation. * Update changie doc --- .../unreleased/Features-20241121-125630.yaml | 6 +++ core/dbt/context/providers.py | 5 ++- core/dbt/contracts/graph/nodes.py | 22 +++++++++++ .../incremental/microbatch.py | 25 ++++++------ core/dbt/task/run.py | 25 +++++++++--- .../functional/microbatch/test_microbatch.py | 38 +++++++++++-------- tests/unit/contracts/graph/test_manifest.py | 1 + .../incremental/test_microbatch.py | 9 ++--- 8 files changed, 90 insertions(+), 41 deletions(-) create mode 100644 .changes/unreleased/Features-20241121-125630.yaml diff --git a/.changes/unreleased/Features-20241121-125630.yaml b/.changes/unreleased/Features-20241121-125630.yaml new file mode 100644 index 00000000000..befd9fac790 --- /dev/null +++ b/.changes/unreleased/Features-20241121-125630.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Add `batch` context object to model jinja context +time: 2024-11-21T12:56:30.715473-06:00 +custom: + Author: QMalcolm + Issue: "11025" diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 188ad5480b0..f9d436a7840 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -244,9 +244,10 @@ def resolve_event_time_filter(self, target: ManifestNode) -> Optional[EventTimeF and self.model.config.materialized == "incremental" and self.model.config.incremental_strategy == "microbatch" and self.manifest.use_microbatch_batches(project_name=self.config.project_name) + and self.model.batch is not None ): - start = self.model.config.get("__dbt_internal_microbatch_event_time_start") - end = self.model.config.get("__dbt_internal_microbatch_event_time_end") + start = self.model.batch.event_time_start + end = self.model.batch.event_time_end if start is not None or end is not None: event_time_filter = EventTimeFilter( diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index 0eaf758ae5a..8fc39c7621b 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -93,6 +93,7 @@ ConstraintType, ModelLevelConstraint, ) +from dbt_common.dataclass_schema import dbtClassMixin from dbt_common.events.contextvars import set_log_contextvars from dbt_common.events.functions import warn_or_error @@ -442,9 +443,30 @@ def resource_class(cls) -> Type[HookNodeResource]: return HookNodeResource +@dataclass +class BatchContext(dbtClassMixin): + id: str + event_time_start: datetime + event_time_end: datetime + + def __post_serialize__(self, data, context): + # This is insane, but necessary, I apologize. Mashumaro handles the + # dictification of this class via a compile time generated `to_dict` + # method based off of the _typing_ of th class. By default `datetime` + # types are converted to strings. We don't want that, we want them to + # stay datetimes. + # Note: This is safe because the `BatchContext` isn't part of the artifact + # and thus doesn't get written out. + new_data = super().__post_serialize__(data, context) + new_data["event_time_start"] = self.event_time_start + new_data["event_time_end"] = self.event_time_end + return new_data + + @dataclass class ModelNode(ModelResource, CompiledNode): previous_batch_results: Optional[BatchResults] = None + batch: Optional[BatchContext] = None _has_this: Optional[bool] = None def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None): diff --git a/core/dbt/materializations/incremental/microbatch.py b/core/dbt/materializations/incremental/microbatch.py index b89c834d4a2..6de6945704c 100644 --- a/core/dbt/materializations/incremental/microbatch.py +++ b/core/dbt/materializations/incremental/microbatch.py @@ -100,25 +100,25 @@ def build_batches(self, start: datetime, end: datetime) -> List[BatchType]: return batches - def build_batch_context(self, incremental_batch: bool) -> Dict[str, Any]: + def build_jinja_context_for_batch(self, incremental_batch: bool) -> Dict[str, Any]: """ Create context with entries that reflect microbatch model + incremental execution state Assumes self.model has been (re)-compiled with necessary batch filters applied. """ - batch_context: Dict[str, Any] = {} + jinja_context: Dict[str, Any] = {} # Microbatch model properties - batch_context["model"] = self.model.to_dict() - batch_context["sql"] = self.model.compiled_code - batch_context["compiled_code"] = self.model.compiled_code + jinja_context["model"] = self.model.to_dict() + jinja_context["sql"] = self.model.compiled_code + jinja_context["compiled_code"] = self.model.compiled_code # Add incremental context variables for batches running incrementally if incremental_batch: - batch_context["is_incremental"] = lambda: True - batch_context["should_full_refresh"] = lambda: False + jinja_context["is_incremental"] = lambda: True + jinja_context["should_full_refresh"] = lambda: False - return batch_context + return jinja_context @staticmethod def offset_timestamp(timestamp: datetime, batch_size: BatchSize, offset: int) -> datetime: @@ -193,12 +193,11 @@ def truncate_timestamp(timestamp: datetime, batch_size: BatchSize) -> datetime: return truncated @staticmethod - def format_batch_start( - batch_start: Optional[datetime], batch_size: BatchSize - ) -> Optional[str]: - if batch_start is None: - return batch_start + def batch_id(start_time: datetime, batch_size: BatchSize) -> str: + return MicrobatchBuilder.format_batch_start(start_time, batch_size).replace("-", "") + @staticmethod + def format_batch_start(batch_start: datetime, batch_size: BatchSize) -> str: return str( batch_start.date() if (batch_start and batch_size != BatchSize.hour) else batch_start ) diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index e52dd8d0abd..55b73be3d80 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -27,7 +27,7 @@ from dbt.config import RuntimeConfig from dbt.context.providers import generate_runtime_model_context from dbt.contracts.graph.manifest import Manifest -from dbt.contracts.graph.nodes import HookNode, ModelNode, ResultNode +from dbt.contracts.graph.nodes import BatchContext, HookNode, ModelNode, ResultNode from dbt.events.types import ( GenericExceptionOnRun, LogHookEndLine, @@ -341,6 +341,13 @@ def __init__(self, config, adapter, node, node_index: int, num_nodes: int): self.batches: Dict[int, BatchType] = {} self.relation_exists: bool = False + def compile(self, manifest: Manifest): + # The default compile function is _always_ called. However, we do our + # compilation _later_ in `_execute_microbatch_materialization`. This + # meant the node was being compiled _twice_ for each batch. To get around + # this, we've overriden the default compile method to do nothing + return self.node + def set_batch_idx(self, batch_idx: int) -> None: self.batch_idx = batch_idx @@ -353,7 +360,7 @@ def set_batches(self, batches: Dict[int, BatchType]) -> None: def describe_node(self) -> str: return f"{self.node.language} microbatch model {self.get_node_representation()}" - def describe_batch(self, batch_start: Optional[datetime]) -> str: + def describe_batch(self, batch_start: datetime) -> str: # Only visualize date if batch_start year/month/day formatted_batch_start = MicrobatchBuilder.format_batch_start( batch_start, self.node.config.batch_size @@ -530,10 +537,16 @@ def _execute_microbatch_materialization( # call materialization_macro to get a batch-level run result start_time = time.perf_counter() try: - # Set start/end in context prior to re-compiling + # LEGACY: Set start/end in context prior to re-compiling (Will be removed for 1.10+) + # TODO: REMOVE before 1.10 GA model.config["__dbt_internal_microbatch_event_time_start"] = batch[0] model.config["__dbt_internal_microbatch_event_time_end"] = batch[1] - + # Create batch context on model node prior to re-compiling + model.batch = BatchContext( + id=MicrobatchBuilder.batch_id(batch[0], model.config.batch_size), + event_time_start=batch[0], + event_time_end=batch[1], + ) # Recompile node to re-resolve refs with event time filters rendered, update context self.compiler.compile_node( model, @@ -544,10 +557,10 @@ def _execute_microbatch_materialization( ), ) # Update jinja context with batch context members - batch_context = microbatch_builder.build_batch_context( + jinja_context = microbatch_builder.build_jinja_context_for_batch( incremental_batch=self.relation_exists ) - context.update(batch_context) + context.update(jinja_context) # Materialize batch and cache any materialized relations result = MacroGenerator( diff --git a/tests/functional/microbatch/test_microbatch.py b/tests/functional/microbatch/test_microbatch.py index f0f4097b9c9..e3acc415273 100644 --- a/tests/functional/microbatch/test_microbatch.py +++ b/tests/functional/microbatch/test_microbatch.py @@ -64,8 +64,8 @@ select * from {{ ref('microbatch_model') }} """ -invalid_batch_context_macro_sql = """ -{% macro check_invalid_batch_context() %} +invalid_batch_jinja_context_macro_sql = """ +{% macro check_invalid_batch_jinja_context() %} {% if model is not mapping %} {{ exceptions.raise_compiler_error("`model` is invalid: expected mapping type") }} @@ -83,9 +83,9 @@ """ microbatch_model_with_context_checks_sql = """ -{{ config(pre_hook="{{ check_invalid_batch_context() }}", materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }} +{{ config(pre_hook="{{ check_invalid_batch_jinja_context() }}", materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }} -{{ check_invalid_batch_context() }} +{{ check_invalid_batch_jinja_context() }} select * from {{ ref('input_model') }} """ @@ -404,7 +404,7 @@ class TestMicrobatchJinjaContext(BaseMicrobatchTest): @pytest.fixture(scope="class") def macros(self): - return {"check_batch_context.sql": invalid_batch_context_macro_sql} + return {"check_batch_jinja_context.sql": invalid_batch_jinja_context_macro_sql} @pytest.fixture(scope="class") def models(self): @@ -498,6 +498,13 @@ def test_run_with_event_time(self, project): {{ config(materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }} {{ log("start: "~ model.config.__dbt_internal_microbatch_event_time_start, info=True)}} {{ log("end: "~ model.config.__dbt_internal_microbatch_event_time_end, info=True)}} +{% if model.batch %} +{{ log("batch.event_time_start: "~ model.batch.event_time_start, info=True)}} +{{ log("batch.event_time_end: "~ model.batch.event_time_end, info=True)}} +{{ log("batch.id: "~ model.batch.id, info=True)}} +{{ log("start timezone: "~ model.batch.event_time_start.tzinfo, info=True)}} +{{ log("end timezone: "~ model.batch.event_time_end.tzinfo, info=True)}} +{% endif %} select * from {{ ref('input_model') }} """ @@ -516,12 +523,23 @@ def test_run_with_event_time_logs(self, project): assert "start: 2020-01-01 00:00:00+00:00" in logs assert "end: 2020-01-02 00:00:00+00:00" in logs + assert "batch.event_time_start: 2020-01-01 00:00:00+00:00" in logs + assert "batch.event_time_end: 2020-01-02 00:00:00+00:00" in logs + assert "batch.id: 20200101" in logs + assert "start timezone: UTC" in logs + assert "end timezone: UTC" in logs assert "start: 2020-01-02 00:00:00+00:00" in logs assert "end: 2020-01-03 00:00:00+00:00" in logs + assert "batch.event_time_start: 2020-01-02 00:00:00+00:00" in logs + assert "batch.event_time_end: 2020-01-03 00:00:00+00:00" in logs + assert "batch.id: 20200102" in logs assert "start: 2020-01-03 00:00:00+00:00" in logs assert "end: 2020-01-03 13:57:00+00:00" in logs + assert "batch.event_time_start: 2020-01-03 00:00:00+00:00" in logs + assert "batch.event_time_end: 2020-01-03 13:57:00+00:00" in logs + assert "batch.id: 20200103" in logs microbatch_model_failing_incremental_partition_sql = """ @@ -675,16 +693,6 @@ def test_run_with_event_time(self, project): with patch_microbatch_end_time("2020-01-03 13:57:00"): run_dbt(["run"]) - # Compiled paths - compiled model without filter only - assert read_file( - project.project_root, - "target", - "compiled", - "test", - "models", - "microbatch_model.sql", - ) - # Compiled paths - batch compilations assert read_file( project.project_root, diff --git a/tests/unit/contracts/graph/test_manifest.py b/tests/unit/contracts/graph/test_manifest.py index 0f3a80e5039..3505ee80037 100644 --- a/tests/unit/contracts/graph/test_manifest.py +++ b/tests/unit/contracts/graph/test_manifest.py @@ -96,6 +96,7 @@ "deprecation_date", "defer_relation", "time_spine", + "batch", } ) diff --git a/tests/unit/materializations/incremental/test_microbatch.py b/tests/unit/materializations/incremental/test_microbatch.py index f114d8649c3..3d827a79975 100644 --- a/tests/unit/materializations/incremental/test_microbatch.py +++ b/tests/unit/materializations/incremental/test_microbatch.py @@ -489,11 +489,11 @@ def test_build_batches(self, microbatch_model, start, end, batch_size, expected_ assert len(actual_batches) == len(expected_batches) assert actual_batches == expected_batches - def test_build_batch_context_incremental_batch(self, microbatch_model): + def test_build_jinja_context_for_incremental_batch(self, microbatch_model): microbatch_builder = MicrobatchBuilder( model=microbatch_model, is_incremental=True, event_time_start=None, event_time_end=None ) - context = microbatch_builder.build_batch_context(incremental_batch=True) + context = microbatch_builder.build_jinja_context_for_batch(incremental_batch=True) assert context["model"] == microbatch_model.to_dict() assert context["sql"] == microbatch_model.compiled_code @@ -502,11 +502,11 @@ def test_build_batch_context_incremental_batch(self, microbatch_model): assert context["is_incremental"]() is True assert context["should_full_refresh"]() is False - def test_build_batch_context_incremental_batch_false(self, microbatch_model): + def test_build_jinja_context_for_incremental_batch_false(self, microbatch_model): microbatch_builder = MicrobatchBuilder( model=microbatch_model, is_incremental=True, event_time_start=None, event_time_end=None ) - context = microbatch_builder.build_batch_context(incremental_batch=False) + context = microbatch_builder.build_jinja_context_for_batch(incremental_batch=False) assert context["model"] == microbatch_model.to_dict() assert context["sql"] == microbatch_model.compiled_code @@ -605,7 +605,6 @@ def test_truncate_timestamp(self, timestamp, batch_size, expected_timestamp): @pytest.mark.parametrize( "batch_size,batch_start,expected_formatted_batch_start", [ - (None, None, None), (BatchSize.year, datetime(2020, 1, 1, 1), "2020-01-01"), (BatchSize.month, datetime(2020, 1, 1, 1), "2020-01-01"), (BatchSize.day, datetime(2020, 1, 1, 1), "2020-01-01"),