diff --git a/.changes/unreleased/Features-20241121-125630.yaml b/.changes/unreleased/Features-20241121-125630.yaml
new file mode 100644
index 00000000000..befd9fac790
--- /dev/null
+++ b/.changes/unreleased/Features-20241121-125630.yaml
@@ -0,0 +1,6 @@
+kind: Features
+body: Add `batch` context object to model jinja context
+time: 2024-11-21T12:56:30.715473-06:00
+custom:
+  Author: QMalcolm
+  Issue: "11025"
diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py
index 188ad5480b0..f9d436a7840 100644
--- a/core/dbt/context/providers.py
+++ b/core/dbt/context/providers.py
@@ -244,9 +244,10 @@ def resolve_event_time_filter(self, target: ManifestNode) -> Optional[EventTimeF
             and self.model.config.materialized == "incremental"
             and self.model.config.incremental_strategy == "microbatch"
             and self.manifest.use_microbatch_batches(project_name=self.config.project_name)
+            and self.model.batch is not None
         ):
-            start = self.model.config.get("__dbt_internal_microbatch_event_time_start")
-            end = self.model.config.get("__dbt_internal_microbatch_event_time_end")
+            start = self.model.batch.event_time_start
+            end = self.model.batch.event_time_end

             if start is not None or end is not None:
                 event_time_filter = EventTimeFilter(
diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py
index 0eaf758ae5a..8fc39c7621b 100644
--- a/core/dbt/contracts/graph/nodes.py
+++ b/core/dbt/contracts/graph/nodes.py
@@ -93,6 +93,7 @@
     ConstraintType,
     ModelLevelConstraint,
 )
+from dbt_common.dataclass_schema import dbtClassMixin
 from dbt_common.events.contextvars import set_log_contextvars
 from dbt_common.events.functions import warn_or_error

@@ -442,9 +443,30 @@ def resource_class(cls) -> Type[HookNodeResource]:
         return HookNodeResource


+@dataclass
+class BatchContext(dbtClassMixin):
+    id: str
+    event_time_start: datetime
+    event_time_end: datetime
+
+    def __post_serialize__(self, data, context):
+        # This is insane, but necessary, I apologize. Mashumaro handles the
+        # dictification of this class via a compile-time generated `to_dict`
+        # method based on the _typing_ of the class. By default `datetime`
+        # types are converted to strings. We don't want that; we want them to
+        # stay datetimes.
+        # Note: This is safe because the `BatchContext` isn't part of the artifact
+        # and thus doesn't get written out.
+        new_data = super().__post_serialize__(data, context)
+        new_data["event_time_start"] = self.event_time_start
+        new_data["event_time_end"] = self.event_time_end
+        return new_data
+
+
 @dataclass
 class ModelNode(ModelResource, CompiledNode):
     previous_batch_results: Optional[BatchResults] = None
+    batch: Optional[BatchContext] = None
     _has_this: Optional[bool] = None

     def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
diff --git a/core/dbt/materializations/incremental/microbatch.py b/core/dbt/materializations/incremental/microbatch.py
index b89c834d4a2..6de6945704c 100644
--- a/core/dbt/materializations/incremental/microbatch.py
+++ b/core/dbt/materializations/incremental/microbatch.py
@@ -100,25 +100,25 @@ def build_batches(self, start: datetime, end: datetime) -> List[BatchType]:

         return batches

-    def build_batch_context(self, incremental_batch: bool) -> Dict[str, Any]:
+    def build_jinja_context_for_batch(self, incremental_batch: bool) -> Dict[str, Any]:
         """
         Create context with entries that reflect microbatch model + incremental execution state

         Assumes self.model has been (re)-compiled with necessary batch filters applied.
         """
-        batch_context: Dict[str, Any] = {}
+        jinja_context: Dict[str, Any] = {}

         # Microbatch model properties
-        batch_context["model"] = self.model.to_dict()
-        batch_context["sql"] = self.model.compiled_code
-        batch_context["compiled_code"] = self.model.compiled_code
+        jinja_context["model"] = self.model.to_dict()
+        jinja_context["sql"] = self.model.compiled_code
+        jinja_context["compiled_code"] = self.model.compiled_code

         # Add incremental context variables for batches running incrementally
         if incremental_batch:
-            batch_context["is_incremental"] = lambda: True
-            batch_context["should_full_refresh"] = lambda: False
+            jinja_context["is_incremental"] = lambda: True
+            jinja_context["should_full_refresh"] = lambda: False

-        return batch_context
+        return jinja_context

     @staticmethod
     def offset_timestamp(timestamp: datetime, batch_size: BatchSize, offset: int) -> datetime:
@@ -193,12 +193,11 @@ def truncate_timestamp(timestamp: datetime, batch_size: BatchSize) -> datetime:
         return truncated

     @staticmethod
-    def format_batch_start(
-        batch_start: Optional[datetime], batch_size: BatchSize
-    ) -> Optional[str]:
-        if batch_start is None:
-            return batch_start
+    def batch_id(start_time: datetime, batch_size: BatchSize) -> str:
+        return MicrobatchBuilder.format_batch_start(start_time, batch_size).replace("-", "")

+    @staticmethod
+    def format_batch_start(batch_start: datetime, batch_size: BatchSize) -> str:
         return str(
             batch_start.date() if (batch_start and batch_size != BatchSize.hour) else batch_start
         )
diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py
index e52dd8d0abd..55b73be3d80 100644
--- a/core/dbt/task/run.py
+++ b/core/dbt/task/run.py
@@ -27,7 +27,7 @@
 from dbt.config import RuntimeConfig
 from dbt.context.providers import generate_runtime_model_context
 from dbt.contracts.graph.manifest import Manifest
-from dbt.contracts.graph.nodes import HookNode, ModelNode, ResultNode
+from dbt.contracts.graph.nodes import BatchContext, HookNode, ModelNode, ResultNode
 from dbt.events.types import (
     GenericExceptionOnRun,
     LogHookEndLine,
@@ -341,6 +341,13 @@ def __init__(self, config, adapter, node, node_index: int, num_nodes: int):
         self.batches: Dict[int, BatchType] = {}
         self.relation_exists: bool = False

+    def compile(self, manifest: Manifest):
+        # The default compile function is _always_ called. However, we do our
+        # compilation _later_ in `_execute_microbatch_materialization`. This
+        # meant the node was being compiled _twice_ for each batch. To get around
+        # this, we've overridden the default compile method to do nothing.
+        return self.node
+
     def set_batch_idx(self, batch_idx: int) -> None:
         self.batch_idx = batch_idx

@@ -353,7 +360,7 @@ def set_batches(self, batches: Dict[int, BatchType]) -> None:
     def describe_node(self) -> str:
         return f"{self.node.language} microbatch model {self.get_node_representation()}"

-    def describe_batch(self, batch_start: Optional[datetime]) -> str:
+    def describe_batch(self, batch_start: datetime) -> str:
         # Only visualize date if batch_start year/month/day
         formatted_batch_start = MicrobatchBuilder.format_batch_start(
             batch_start, self.node.config.batch_size
@@ -530,10 +537,16 @@ def _execute_microbatch_materialization(
         # call materialization_macro to get a batch-level run result
         start_time = time.perf_counter()
         try:
-            # Set start/end in context prior to re-compiling
+            # LEGACY: Set start/end in context prior to re-compiling (will be removed for 1.10+)
+            # TODO: REMOVE before 1.10 GA
             model.config["__dbt_internal_microbatch_event_time_start"] = batch[0]
             model.config["__dbt_internal_microbatch_event_time_end"] = batch[1]
-
+            # Create batch context on model node prior to re-compiling
+            model.batch = BatchContext(
+                id=MicrobatchBuilder.batch_id(batch[0], model.config.batch_size),
+                event_time_start=batch[0],
+                event_time_end=batch[1],
+            )
             # Recompile node to re-resolve refs with event time filters rendered, update context
             self.compiler.compile_node(
                 model,
@@ -544,10 +557,10 @@ def _execute_microbatch_materialization(
                 ),
             )
             # Update jinja context with batch context members
-            batch_context = microbatch_builder.build_batch_context(
+            jinja_context = microbatch_builder.build_jinja_context_for_batch(
                 incremental_batch=self.relation_exists
             )
-            context.update(batch_context)
+            context.update(jinja_context)

             # Materialize batch and cache any materialized relations
             result = MacroGenerator(
diff --git a/tests/functional/microbatch/test_microbatch.py b/tests/functional/microbatch/test_microbatch.py
index f0f4097b9c9..e3acc415273 100644
--- a/tests/functional/microbatch/test_microbatch.py
+++ b/tests/functional/microbatch/test_microbatch.py
@@ -64,8 +64,8 @@
 select * from {{ ref('microbatch_model') }}
 """

-invalid_batch_context_macro_sql = """
-{% macro check_invalid_batch_context() %}
+invalid_batch_jinja_context_macro_sql = """
+{% macro check_invalid_batch_jinja_context() %}

 {% if model is not mapping %}
     {{ exceptions.raise_compiler_error("`model` is invalid: expected mapping type") }}
@@ -83,9 +83,9 @@
 """

 microbatch_model_with_context_checks_sql = """
-{{ config(pre_hook="{{ check_invalid_batch_context() }}", materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }}
+{{ config(pre_hook="{{ check_invalid_batch_jinja_context() }}", materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }}

-{{ check_invalid_batch_context() }}
+{{ check_invalid_batch_jinja_context() }}

 select * from {{ ref('input_model') }}
 """
@@ -404,7 +404,7 @@ class TestMicrobatchJinjaContext(BaseMicrobatchTest):

     @pytest.fixture(scope="class")
     def macros(self):
-        return {"check_batch_context.sql": invalid_batch_context_macro_sql}
+        return {"check_batch_jinja_context.sql": invalid_batch_jinja_context_macro_sql}

     @pytest.fixture(scope="class")
     def models(self):
@@ -498,6 +498,13 @@ def test_run_with_event_time(self, project):
 {{ config(materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }}
 {{ log("start: "~ model.config.__dbt_internal_microbatch_event_time_start, info=True)}}
 {{ log("end: "~ model.config.__dbt_internal_microbatch_event_time_end, info=True)}}
+{% if model.batch %}
+{{ log("batch.event_time_start: "~ model.batch.event_time_start, info=True)}}
+{{ log("batch.event_time_end: "~ model.batch.event_time_end, info=True)}}
+{{ log("batch.id: "~ model.batch.id, info=True)}}
+{{ log("start timezone: "~ model.batch.event_time_start.tzinfo, info=True)}}
+{{ log("end timezone: "~ model.batch.event_time_end.tzinfo, info=True)}}
+{% endif %}
 select * from {{ ref('input_model') }}
 """

@@ -516,12 +523,23 @@ def test_run_with_event_time_logs(self, project):

         assert "start: 2020-01-01 00:00:00+00:00" in logs
         assert "end: 2020-01-02 00:00:00+00:00" in logs
+        assert "batch.event_time_start: 2020-01-01 00:00:00+00:00" in logs
+        assert "batch.event_time_end: 2020-01-02 00:00:00+00:00" in logs
+        assert "batch.id: 20200101" in logs
+        assert "start timezone: UTC" in logs
+        assert "end timezone: UTC" in logs

         assert "start: 2020-01-02 00:00:00+00:00" in logs
         assert "end: 2020-01-03 00:00:00+00:00" in logs
+        assert "batch.event_time_start: 2020-01-02 00:00:00+00:00" in logs
+        assert "batch.event_time_end: 2020-01-03 00:00:00+00:00" in logs
+        assert "batch.id: 20200102" in logs

         assert "start: 2020-01-03 00:00:00+00:00" in logs
         assert "end: 2020-01-03 13:57:00+00:00" in logs
+        assert "batch.event_time_start: 2020-01-03 00:00:00+00:00" in logs
+        assert "batch.event_time_end: 2020-01-03 13:57:00+00:00" in logs
+        assert "batch.id: 20200103" in logs


 microbatch_model_failing_incremental_partition_sql = """
@@ -675,16 +693,6 @@ def test_run_with_event_time(self, project):
         with patch_microbatch_end_time("2020-01-03 13:57:00"):
             run_dbt(["run"])

-        # Compiled paths - compiled model without filter only
-        assert read_file(
-            project.project_root,
-            "target",
-            "compiled",
-            "test",
-            "models",
-            "microbatch_model.sql",
-        )
-
         # Compiled paths - batch compilations
         assert read_file(
             project.project_root,
diff --git a/tests/unit/contracts/graph/test_manifest.py b/tests/unit/contracts/graph/test_manifest.py
index 0f3a80e5039..3505ee80037 100644
--- a/tests/unit/contracts/graph/test_manifest.py
+++ b/tests/unit/contracts/graph/test_manifest.py
@@ -96,6 +96,7 @@
         "deprecation_date",
         "defer_relation",
         "time_spine",
+        "batch",
     }
 )

diff --git a/tests/unit/materializations/incremental/test_microbatch.py b/tests/unit/materializations/incremental/test_microbatch.py
index f114d8649c3..3d827a79975 100644
--- a/tests/unit/materializations/incremental/test_microbatch.py
+++ b/tests/unit/materializations/incremental/test_microbatch.py
@@ -489,11 +489,11 @@ def test_build_batches(self, microbatch_model, start, end, batch_size, expected_
         assert len(actual_batches) == len(expected_batches)
         assert actual_batches == expected_batches

-    def test_build_batch_context_incremental_batch(self, microbatch_model):
+    def test_build_jinja_context_for_incremental_batch(self, microbatch_model):
         microbatch_builder = MicrobatchBuilder(
             model=microbatch_model, is_incremental=True, event_time_start=None, event_time_end=None
         )
-        context = microbatch_builder.build_batch_context(incremental_batch=True)
+        context = microbatch_builder.build_jinja_context_for_batch(incremental_batch=True)

         assert context["model"] == microbatch_model.to_dict()
         assert context["sql"] == microbatch_model.compiled_code
@@ -502,11 +502,11 @@ def test_build_batch_context_incremental_batch(self, microbatch_model):
         assert context["is_incremental"]() is True
         assert context["should_full_refresh"]() is False

-    def test_build_batch_context_incremental_batch_false(self, microbatch_model):
+    def test_build_jinja_context_for_incremental_batch_false(self, microbatch_model):
         microbatch_builder = MicrobatchBuilder(
             model=microbatch_model, is_incremental=True, event_time_start=None, event_time_end=None
         )
-        context = microbatch_builder.build_batch_context(incremental_batch=False)
+        context = microbatch_builder.build_jinja_context_for_batch(incremental_batch=False)

         assert context["model"] == microbatch_model.to_dict()
         assert context["sql"] == microbatch_model.compiled_code
@@ -605,7 +605,6 @@ def test_truncate_timestamp(self, timestamp, batch_size, expected_timestamp):
     @pytest.mark.parametrize(
         "batch_size,batch_start,expected_formatted_batch_start",
         [
-            (None, None, None),
             (BatchSize.year, datetime(2020, 1, 1, 1), "2020-01-01"),
             (BatchSize.month, datetime(2020, 1, 1, 1), "2020-01-01"),
             (BatchSize.day, datetime(2020, 1, 1, 1), "2020-01-01"),
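
Usage sketch (not part of the diff): how a project's microbatch model could read the new `batch` context object from Jinja, mirroring the functional test fixture above. The `input_model` ref and the config values are illustrative only:

```sql
{{ config(
    materialized='incremental',
    incremental_strategy='microbatch',
    unique_key='id',
    event_time='event_time',
    batch_size='day',
    begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)
) }}

{# `model.batch` is only populated while a batch is executing, so guard on it #}
{% if model.batch %}
    {# `model.batch.id` is the formatted batch start with the dashes stripped, e.g. "20200101" #}
    {{ log("batch " ~ model.batch.id ~ ": " ~ model.batch.event_time_start ~ " -> " ~ model.batch.event_time_end, info=True) }}
{% endif %}

select * from {{ ref('input_model') }}
```

Note that `event_time_start` and `event_time_end` stay timezone-aware `datetime` objects in the context (this is what the `__post_serialize__` override on `BatchContext` preserves), so Jinja can call e.g. `.tzinfo` on them directly, as the functional test does.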
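And a minimal sketch of the new `batch_id` helper's behavior, assuming dbt-core's internal import paths for `MicrobatchBuilder` and `BatchSize` (the same ones the unit tests above rely on):

```python
from datetime import datetime

from dbt.artifacts.resources.types import BatchSize
from dbt.materializations.incremental.microbatch import MicrobatchBuilder

# For non-hourly batch sizes, format_batch_start keeps only the date
# ("2020-01-01"); batch_id then strips the dashes.
assert MicrobatchBuilder.batch_id(datetime(2020, 1, 1), BatchSize.day) == "20200101"

# For hourly batches, format_batch_start keeps the full timestamp, so the
# id retains the time component.
assert MicrobatchBuilder.batch_id(datetime(2020, 1, 1, 5), BatchSize.hour) == "20200101 05:00:00"
```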