Merge branch 'main' into microbatch-warn-when-no-event-time-config-input
MichelleArk committed Oct 28, 2024
2 parents 6e58956 + 3224589 commit 907d923
Showing 7 changed files with 150 additions and 21 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20241028-132751.yaml
@@ -0,0 +1,6 @@
kind: Fixes
body: 'Fix ''model'' jinja context variable type to dict '
time: 2024-10-28T13:27:51.604093-04:00
custom:
Author: michelleark
Issue: "10927"
22 changes: 21 additions & 1 deletion core/dbt/materializations/incremental/microbatch.py
@@ -1,5 +1,5 @@
from datetime import datetime, timedelta
-from typing import List, Optional
+from typing import Any, Dict, List, Optional

import pytz

@@ -99,6 +99,26 @@ def build_batches(self, start: datetime, end: datetime) -> List[BatchType]:

return batches

def build_batch_context(self, incremental_batch: bool) -> Dict[str, Any]:
"""
Create context with entries that reflect microbatch model + incremental execution state
Assumes self.model has been (re)-compiled with necessary batch filters applied.
"""
batch_context: Dict[str, Any] = {}

# Microbatch model properties
batch_context["model"] = self.model.to_dict()
batch_context["sql"] = self.model.compiled_code
batch_context["compiled_code"] = self.model.compiled_code

# Add incremental context variables for batches running incrementally
if incremental_batch:
batch_context["is_incremental"] = lambda: True
batch_context["should_full_refresh"] = lambda: False

return batch_context

@staticmethod
def offset_timestamp(timestamp: datetime, batch_size: BatchSize, offset: int) -> datetime:
"""Truncates the passed in timestamp based on the batch_size and then applies the offset by the batch_size.
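A minimal sketch of how the new `build_batch_context` helper is driven (a `MagicMock` stands in for a compiled `ModelNode`, as in the unit tests further down; this is not part of the diff):

```python
from unittest.mock import MagicMock

from dbt.materializations.incremental.microbatch import MicrobatchBuilder

model = MagicMock()  # stand-in for a compiled ModelNode with microbatch config

builder = MicrobatchBuilder(
    model=model,
    is_incremental=True,
    event_time_start=None,
    event_time_end=None,
)

# First batch against a brand-new relation: no incremental context vars yet.
ctx = builder.build_batch_context(incremental_batch=False)
assert "is_incremental" not in ctx

# Later batches: the incremental callables are injected alongside model/sql.
ctx = builder.build_batch_context(incremental_batch=True)
assert ctx["is_incremental"]() is True
assert ctx["should_full_refresh"]() is False
```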
49 changes: 31 additions & 18 deletions core/dbt/task/run.py
@@ -3,6 +3,7 @@
import threading
import time
from copy import deepcopy
from dataclasses import asdict
from datetime import datetime
from typing import AbstractSet, Any, Dict, Iterable, List, Optional, Set, Tuple, Type

@@ -100,7 +101,14 @@ def get_execution_status(sql: str, adapter: BaseAdapter) -> Tuple[RunStatus, str
return status, message


-def track_model_run(index, num_nodes, run_model_result):
+def _get_adapter_info(adapter, run_model_result) -> Dict[str, Any]:
+    """Each adapter returns a dataclass with a flexible dictionary for
+    adapter-specific fields. Only the non-'model_adapter_details' fields
+    are guaranteed cross-adapter."""
+    return asdict(adapter.get_adapter_run_info(run_model_result.node.config)) if adapter else {}
+
+
+def track_model_run(index, num_nodes, run_model_result, adapter=None):
if tracking.active_user is None:
raise DbtInternalError("cannot track model run with no active user")
invocation_id = get_invocation_id()
@@ -116,6 +124,7 @@ def track_model_run(index, num_nodes, run_model_result):
contract_enforced = False
versioned = False
incremental_strategy = None

tracking.track_model_run(
{
"invocation_id": invocation_id,
@@ -135,6 +144,7 @@
"contract_enforced": contract_enforced,
"access": access,
"versioned": versioned,
"adapter_info": _get_adapter_info(adapter, run_model_result),
}
)
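For orientation, the flattening in `_get_adapter_info` comes from `asdict`: the adapter hands back a dataclass and the tracking payload wants plain dicts. A self-contained sketch with an illustrative dataclass (the real field names are asserted in `test_adapter_info_tracking` at the bottom of this page):

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict


@dataclass
class AdapterRunInfo:  # illustrative stand-in for the adapter's return type
    adapter_name: str
    adapter_version: str
    base_adapter_version: str
    model_adapter_details: Dict[str, Any]


info = AdapterRunInfo("postgres", "1.9.0", "1.10.0", {})  # made-up versions

# asdict() recursively converts the dataclass (nested dataclasses included)
# into plain dicts -- the shape the new "adapter_info" tracking field expects.
payload = asdict(info)
assert payload["adapter_name"] == "postgres"
assert payload["model_adapter_details"] == {}
```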

@@ -283,7 +293,7 @@ def before_execute(self) -> None:
self.print_start_line()

def after_execute(self, result) -> None:
-        track_model_run(self.node_index, self.num_nodes, result)
+        track_model_run(self.node_index, self.num_nodes, result, adapter=self.adapter)
self.print_result_line(result)

def _build_run_model_result(self, model, context, elapsed_time: float = 0.0):
@@ -489,28 +499,29 @@ def _execute_microbatch_materialization(
materialization_macro: MacroProtocol,
) -> List[RunResult]:
batch_results: List[RunResult] = []
+        microbatch_builder = MicrobatchBuilder(
+            model=model,
+            is_incremental=self._is_incremental(model),
+            event_time_start=getattr(self.config.args, "EVENT_TIME_START", None),
+            event_time_end=getattr(self.config.args, "EVENT_TIME_END", None),
+            default_end_time=self.config.invoked_at,
+        )
+        # Indicates whether current batch should be run incrementally
+        incremental_batch = False

# Note currently (9/30/2024) model.batch_info is only ever _not_ `None`
# IFF `dbt retry` is being run and the microbatch model had batches which
# failed on the run of the model (which is being retried)
if model.batch_info is None:
-            microbatch_builder = MicrobatchBuilder(
-                model=model,
-                is_incremental=self._is_incremental(model),
-                event_time_start=getattr(self.config.args, "EVENT_TIME_START", None),
-                event_time_end=getattr(self.config.args, "EVENT_TIME_END", None),
-                default_end_time=self.config.invoked_at,
-            )
end = microbatch_builder.build_end_time()
start = microbatch_builder.build_start_time(end)
batches = microbatch_builder.build_batches(start, end)
else:
batches = model.batch_info.failed
-            # if there is batch info, then don't run as full_refresh and do force is_incremental
+            # If there is batch info, then don't run as full_refresh and do force is_incremental
            # not doing this risks blowing away the work that has already been done
            if self._has_relation(model=model):
-                context["is_incremental"] = lambda: True
-                context["should_full_refresh"] = lambda: False
+                incremental_batch = True

# iterate over each batch, calling materialization_macro to get a batch-level run result
for batch_idx, batch in enumerate(batches):
@@ -532,9 +543,11 @@
batch[0], model.config.batch_size
),
)
-            context["model"] = model
-            context["sql"] = model.compiled_code
-            context["compiled_code"] = model.compiled_code
+            # Update jinja context with batch context members
+            batch_context = microbatch_builder.build_batch_context(
+                incremental_batch=incremental_batch
+            )
+            context.update(batch_context)

# Materialize batch and cache any materialized relations
result = MacroGenerator(
@@ -547,9 +560,9 @@
batch_run_result = self._build_succesful_run_batch_result(
model, context, batch, time.perf_counter() - start_time
)
-                # Update context vars for future batches
-                context["is_incremental"] = lambda: True
-                context["should_full_refresh"] = lambda: False
+                # At least one batch has been inserted successfully!
+                incremental_batch = True

except Exception as e:
exception = e
batch_run_result = self._build_failed_run_batch_result(
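Stripped of materialization plumbing, the loop above manages one piece of state: whether the next batch should see itself as incremental. A behavioral sketch only (`run_batch` is a hypothetical executor, not a dbt API):

```python
from typing import Any, Callable, Dict, List


def run_batches(
    batches: List[Any],
    builder: Any,  # a MicrobatchBuilder
    context: Dict[str, Any],
    run_batch: Callable[[Any, Dict[str, Any]], None],  # hypothetical executor
    relation_already_exists: bool,
) -> List[Any]:
    """Behavioral sketch only -- not the real implementation."""
    failed: List[Any] = []
    # True from the start only on a retry where prior batches already landed.
    incremental_batch = relation_already_exists
    for batch in batches:
        context.update(builder.build_batch_context(incremental_batch=incremental_batch))
        try:
            run_batch(batch, context)
            # Once any batch lands, later batches must run incrementally so
            # they append to (rather than replace) the work already done.
            incremental_batch = True
        except Exception:
            failed.append(batch)  # a failed batch does not flip the flag
    return failed
```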
2 changes: 1 addition & 1 deletion core/dbt/tracking.py
@@ -50,7 +50,7 @@
RESOURCE_COUNTS = "iglu:com.dbt/resource_counts/jsonschema/1-0-1"
RPC_REQUEST_SPEC = "iglu:com.dbt/rpc_request/jsonschema/1-0-1"
RUNNABLE_TIMING = "iglu:com.dbt/runnable/jsonschema/1-0-0"
-RUN_MODEL_SPEC = "iglu:com.dbt/run_model/jsonschema/1-0-4"
+RUN_MODEL_SPEC = "iglu:com.dbt/run_model/jsonschema/1-1-0"
PLUGIN_GET_NODES = "iglu:com.dbt/plugin_get_nodes/jsonschema/1-0-0"

SNOWPLOW_TRACKER_VERSION = Version(snowplow_version)
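The schema URI bump (1-0-4 → 1-1-0) makes room for the new `adapter_info` field. For reference, a sketch of how a self-describing Snowplow event pairs a schema URI with its payload; dbt's tracking plumbing wraps this, and the sample values are made up:

```python
from snowplow_tracker import SelfDescribingJson

RUN_MODEL_SPEC = "iglu:com.dbt/run_model/jsonschema/1-1-0"

event = SelfDescribingJson(
    RUN_MODEL_SPEC,
    {
        "invocation_id": "00000000-0000-0000-0000-000000000000",  # sample
        "run_status": "SUCCESS",                                  # sample
        "adapter_info": {"adapter_name": "postgres"},             # new field
    },
)

# to_json() yields {"schema": <URI>, "data": {...}} for the event payload.
payload = event.to_json()
assert payload["schema"] == RUN_MODEL_SPEC
```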
46 changes: 46 additions & 0 deletions tests/functional/microbatch/test_microbatch.py
@@ -42,6 +42,31 @@
select * from {{ ref('input_model') }}
"""

invalid_batch_context_macro_sql = """
{% macro check_invalid_batch_context() %}
{% if model is not mapping %}
{{ exceptions.raise_compiler_error("`model` is invalid: expected mapping type") }}
{% elif compiled_code and compiled_code is not string %}
{{ exceptions.raise_compiler_error("`compiled_code` is invalid: expected string type") }}
{% elif sql and sql is not string %}
{{ exceptions.raise_compiler_error("`sql` is invalid: expected string type") }}
{% elif is_incremental is not callable %}
{{ exceptions.raise_compiler_error("`is_incremental()` is invalid: expected callable type") }}
{% elif should_full_refresh is not callable %}
{{ exceptions.raise_compiler_error("`should_full_refresh()` is invalid: expected callable type") }}
{% endif %}
{% endmacro %}
"""

microbatch_model_with_context_checks_sql = """
{{ config(pre_hook="{{ check_invalid_batch_context() }}", materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }}
{{ check_invalid_batch_context() }}
select * from {{ ref('input_model') }}
"""

microbatch_model_downstream_sql = """
{{ config(materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }}
select * from {{ ref('microbatch_model') }}
@@ -325,6 +350,27 @@ def test_run_with_event_time(self, project):
self.assert_row_count(project, "microbatch_model", 5)


class TestMicrobatchJinjaContext(BaseMicrobatchTest):

@pytest.fixture(scope="class")
def macros(self):
return {"check_batch_context.sql": invalid_batch_context_macro_sql}

@pytest.fixture(scope="class")
def models(self):
return {
"input_model.sql": input_model_sql,
"microbatch_model.sql": microbatch_model_with_context_checks_sql,
}

@mock.patch.dict(os.environ, {"DBT_EXPERIMENTAL_MICROBATCH": "True"})
def test_run_with_event_time(self, project):
# initial run -- backfills all data
with patch_microbatch_end_time("2020-01-03 13:57:00"):
run_dbt(["run"])
self.assert_row_count(project, "microbatch_model", 3)


class TestMicrobatchWithInputWithoutEventTime(BaseMicrobatchTest):
@pytest.fixture(scope="class")
def models(self):
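The pre-hook macro above leans on Jinja's built-in `mapping`, `string`, and `callable` tests, which is exactly what the `model`-as-dict fix in this commit is about. A standalone sketch with plain `jinja2`, outside dbt:

```python
from jinja2 import Environment

env = Environment()
template = env.from_string(
    "{{ model is mapping }} {{ sql is string }} {{ is_incremental is callable }}"
)

# With the fix, `model` is a plain dict, so the `mapping` test passes.
rendered = template.render(
    model={"name": "microbatch_model"},
    sql="select 1",
    is_incremental=lambda: True,
)
assert rendered == "True True True"
```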
27 changes: 27 additions & 0 deletions tests/unit/materializations/incremental/test_microbatch.py
@@ -488,6 +488,33 @@ def test_build_batches(self, microbatch_model, start, end, batch_size, expected_
assert len(actual_batches) == len(expected_batches)
assert actual_batches == expected_batches

def test_build_batch_context_incremental_batch(self, microbatch_model):
microbatch_builder = MicrobatchBuilder(
model=microbatch_model, is_incremental=True, event_time_start=None, event_time_end=None
)
context = microbatch_builder.build_batch_context(incremental_batch=True)

assert context["model"] == microbatch_model.to_dict()
assert context["sql"] == microbatch_model.compiled_code
assert context["compiled_code"] == microbatch_model.compiled_code

assert context["is_incremental"]() is True
assert context["should_full_refresh"]() is False

def test_build_batch_context_incremental_batch_false(self, microbatch_model):
microbatch_builder = MicrobatchBuilder(
model=microbatch_model, is_incremental=True, event_time_start=None, event_time_end=None
)
context = microbatch_builder.build_batch_context(incremental_batch=False)

assert context["model"] == microbatch_model.to_dict()
assert context["sql"] == microbatch_model.compiled_code
assert context["compiled_code"] == microbatch_model.compiled_code

# Only build is_incremental callables when not first batch
assert "is_incremental" not in context
assert "should_full_refresh" not in context

@pytest.mark.parametrize(
"timestamp,batch_size,offset,expected_timestamp",
[
19 changes: 18 additions & 1 deletion tests/unit/task/test_run.py
@@ -2,6 +2,7 @@
from argparse import Namespace
from dataclasses import dataclass
from datetime import datetime, timedelta
from importlib import import_module
from typing import Optional
from unittest.mock import MagicMock, patch

@@ -18,7 +19,7 @@
from dbt.contracts.graph.nodes import ModelNode
from dbt.events.types import LogModelResult
from dbt.flags import get_flags, set_from_args
-from dbt.task.run import ModelRunner, RunTask
+from dbt.task.run import ModelRunner, RunTask, _get_adapter_info
from dbt.tests.util import safe_set_invocation_context
from dbt_common.events.base_types import EventLevel
from dbt_common.events.event_manager_client import add_callback_to_manager
@@ -68,6 +69,22 @@ def test_run_task_preserve_edges():
mock_node_selector.get_graph_queue.assert_called_with(mock_spec, True)


def test_tracking_fails_safely_for_missing_adapter():
assert {} == _get_adapter_info(None, {})


def test_adapter_info_tracking():
mock_run_result = MagicMock()
mock_run_result.node = MagicMock()
mock_run_result.node.config = {}
assert _get_adapter_info(PostgresAdapter, mock_run_result) == {
"model_adapter_details": {},
"adapter_name": PostgresAdapter.__name__.split("Adapter")[0].lower(),
"adapter_version": import_module("dbt.adapters.postgres.__version__").version,
"base_adapter_version": import_module("dbt.adapters.__about__").version,
}


class TestModelRunner:
@pytest.fixture
def log_model_result_catcher(self) -> EventCatcher:
