Microbatch first last batch serial (#11072) (#11107)
* microbatch: split out first and last batch to run in serial

* only run pre_hook on first batch, post_hook on last batch

* refactor: internalize parallel to RunTask._submit_batch

* Add optional `force_sequential_run` to `_submit_batch` to allow skipping the parallelism check

* Force last batch to run sequentially

* Force first batch to run sequentially

* Remove batch_idx check in `should_run_in_parallel`

`should_run_in_parallel` shouldn't, and no longer needs to, take into
consideration where a batch sits in the larger batch order. The first and
last batch for a microbatch model are now forced to run sequentially
by `handle_microbatch_model` (a simplified sketch follows)
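
In rough terms, the resulting control flow looks like the sketch below, a simplification of the `handle_microbatch_model` changes in the `core/dbt/task/run.py` diff further down. `BatchResult` and `submit` are illustrative stand-ins; in the real code, `RunTask._submit_batch` does the submission and also strips the hooks.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class BatchResult:
    status: str  # "success", "error", or "skipped"


def run_batches(num_batches: int, submit: Callable[..., None]) -> List[BatchResult]:
    # `submit` stands in for RunTask._submit_batch: it runs (or schedules) one
    # batch and appends a BatchResult to `results` when that batch finishes.
    results: List[BatchResult] = []

    # The first batch is forced to run sequentially, so its pre_hook runs
    # exactly once and the target relation exists for every later batch.
    submit(0, results, force_sequential_run=True)
    if num_batches == 1:
        return results  # a lone batch keeps both its pre_hook and post_hook

    # If the first batch failed, every remaining batch is marked skipped.
    skip = results[0].status != "success"

    # Middle batches may run concurrently, subject to should_run_in_parallel().
    for batch_idx in range(1, num_batches - 1):
        submit(batch_idx, results, skip=skip)

    # Wait for the middle batches, then force the last batch to run
    # sequentially so its post_hook fires only after everything else.
    while len(results) < num_batches - 1:
        pass  # busy-wait, mirroring the upstream loop
    submit(num_batches - 1, results, force_sequential_run=True, skip=skip)
    return results
```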

* Begin skipping batches if first batch fails

* Write custom `on_skip` for `MicrobatchModelRunner` to better handle when batches are skipped

This was necessary specifically because the default on_skip set the `X of Y` part
of the skipped log using the `node_index` and the `num_nodes`. If there were 2
nodes and we were on the 4th batch of the second node, we'd get a message like
`SKIPPED 4 of 2...`, which didn't make much sense (illustrated below). In a future
commit we're likely to add a custom event for logging the start, result, and
skipping of batches for better readability of the logs.
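
To make the mismatch concrete (the counts here are illustrative):

```python
# The default on_skip formats "SKIPPED {index} of {total}" from node_index
# and num_nodes, so skipping the 4th batch of the second of two nodes
# produced an index that had nothing to do with the total:
num_nodes = 2
batch_position = 4  # fourth batch of the second node
print(f"SKIPPED {batch_position} of {num_nodes}")  # SKIPPED 4 of 2 -- nonsense

# The custom on_skip counts within the model's own batches instead:
total_batches = 12  # hypothetical batch count for this model
print(f"Batch SKIPPED {batch_position} of {total_batches}")  # Batch SKIPPED 4 of 12
```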

* Add microbatch pre-hook, post-hook, and sequential first/last batch tests

* Fix/Add tests around first-batch failure vs later-batch failure

* Correct MicrobatchModelRunner.on_skip to handle skipping the entire node

Previously `MicrobatchModelRunner.on_skip` only handled when a _batch_ of
the model was being skipped. However, that method is also used when the
entire microbatch model is being skipped due to an upstream node error. Because
we previously _weren't_ handling this second case, it caused an unhandled
runtime exception. Thus, we now check whether we're running a batch or not;
if there is no batch, we fall back to the super's on_skip method.

* Correct conditional logic for setting pre- and post-hooks for batches

Previously we were doing an if+elif for setting pre- and post-hooks
for batches, where the `if` matched if the batch wasn't the first
batch, and the `elif` matched if the batch wasn't the last batch. The
issue with this is that if the `if` was hit, the `elif` _wouldn't_ be hit.
This caused the first batch to appropriately not run the `post-hook`, but
then every batch after it would run the `post-hook` (the shape of the bug
is sketched below).
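
Reduced to its essentials, the bug and the fix look like this, using `SimpleNamespace` as a stand-in for the node's config (compare `_submit_batch` in the `core/dbt/task/run.py` diff below):

```python
from types import SimpleNamespace

batches = [0, 1, 2]
batch_idx = 1  # a middle batch
config = SimpleNamespace(pre_hook=["pre"], post_hook=["post"])

# Buggy: with if/elif, a middle batch takes the first branch and never
# reaches the elif, so its post_hook survives and runs.
if batch_idx != 0:
    config.pre_hook = []
elif batch_idx != len(batches) - 1:
    config.post_hook = []
assert config.post_hook == ["post"]  # bug: middle batch would run post_hook

# Fixed: two independent checks, as in _submit_batch below.
config = SimpleNamespace(pre_hook=["pre"], post_hook=["post"])
if batch_idx != 0:
    config.pre_hook = []
if batch_idx != len(batches) - 1:
    config.post_hook = []
assert config.pre_hook == [] and config.post_hook == []  # only first/last keep hooks
```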

* Add two new event types `LogStartBatch` and `LogBatchResult`

* Update MicrobatchModelRunner to use new batch specific log events

* Fix event testing

* Update microbatch integration tests to catch batch specific event types

---------

Co-authored-by: Quigley Malcolm <[email protected]>
(cherry picked from commit 03fdb4c)

Co-authored-by: Michelle Ark <[email protected]>
github-actions[bot] and MichelleArk authored Dec 9, 2024
1 parent 4e74e69 commit 9f5f002
Showing 8 changed files with 575 additions and 219 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20241206-195308.yaml
@@ -0,0 +1,7 @@
kind: Features
body: Ensure pre/post hooks only run on first/last batch respectively for microbatch
model batches
time: 2024-12-06T19:53:08.928793-06:00
custom:
Author: MichelleArk QMalcolm
Issue: 11094 11104
29 changes: 29 additions & 0 deletions core/dbt/events/core_types.proto
@@ -1690,6 +1690,35 @@ message MicrobatchExecutionDebugMsg {
MicrobatchExecutionDebug data = 2;
}

// Q045
message LogStartBatch {
NodeInfo node_info = 1;
string description = 2;
int32 batch_index = 3;
int32 total_batches = 4;
}

message LogStartBatchMsg {
CoreEventInfo info = 1;
LogStartBatch data = 2;
}

// Q046
message LogBatchResult {
NodeInfo node_info = 1;
string description = 2;
string status = 3;
int32 batch_index = 4;
int32 total_batches = 5;
float execution_time = 6;
Group group = 7;
}

message LogBatchResultMsg {
CoreEventInfo info = 1;
LogBatchResult data = 2;
}

// W - Node testing

// Skipped W001
350 changes: 179 additions & 171 deletions core/dbt/events/core_types_pb2.py

Large diffs are not rendered by default.

45 changes: 45 additions & 0 deletions core/dbt/events/types.py
@@ -1710,6 +1710,51 @@ def message(self) -> str:
return self.msg


class LogStartBatch(InfoLevel):
def code(self) -> str:
return "Q045"

def message(self) -> str:
msg = f"START {self.description}"

# TODO update common so that we can append "batch" in `format_fancy_output_line`
formatted = format_fancy_output_line(
msg=msg,
status="RUN",
index=self.batch_index,
total=self.total_batches,
)
return f"Batch {formatted}"


class LogBatchResult(DynamicLevel):
def code(self) -> str:
return "Q046"

def message(self) -> str:
if self.status == "error":
info = "ERROR creating"
status = red(self.status.upper())
elif self.status == "skipped":
info = "SKIP"
status = yellow(self.status.upper())
else:
info = "OK created"
status = green(self.status)

msg = f"{info} {self.description}"

# TODO update common so that we can append "batch" in `format_fancy_output_line`
formatted = format_fancy_output_line(
msg=msg,
status=status,
index=self.batch_index,
total=self.total_batches,
execution_time=self.execution_time,
)
return f"Batch {formatted}"


# =======================================================
# W - Node testing
# =======================================================
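For orientation, the run task (see the `core/dbt/task/run.py` diff below) fires these events roughly as follows. This is a condensed sketch: the variables (`description`, `batch_idx`, `batches`, `node`) are assumed to be in scope, the import paths are as of this PR, and the rendered line is approximate.

```python
from dbt.events.types import LogStartBatch
from dbt_common.events.functions import fire_event

fire_event(
    LogStartBatch(
        description=description,
        batch_index=batch_idx + 1,
        total_batches=len(batches),
        node_info=node.node_info,
    )
)
# Via format_fancy_output_line, this renders along the lines of:
#   Batch 1 of 3 START <description> ............ [RUN]
```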
171 changes: 131 additions & 40 deletions core/dbt/task/run.py
@@ -30,9 +30,11 @@
from dbt.contracts.graph.nodes import BatchContext, HookNode, ModelNode, ResultNode
from dbt.events.types import (
GenericExceptionOnRun,
LogBatchResult,
LogHookEndLine,
LogHookStartLine,
LogModelResult,
LogStartBatch,
LogStartLine,
MicrobatchExecutionDebug,
)
@@ -397,15 +399,18 @@ def print_batch_result_line(
if result.status == NodeStatus.Error:
status = result.status
level = EventLevel.ERROR
elif result.status == NodeStatus.Skipped:
status = result.status
level = EventLevel.INFO
else:
status = result.message
level = EventLevel.INFO
fire_event(
LogModelResult(
LogBatchResult(
description=description,
status=status,
index=self.batch_idx + 1,
total=len(self.batches),
batch_index=self.batch_idx + 1,
total_batches=len(self.batches),
execution_time=result.execution_time,
node_info=self.node.node_info,
group=group,
Expand All @@ -423,10 +428,10 @@ def print_batch_start_line(self) -> None:

batch_description = self.describe_batch(batch_start)
fire_event(
LogStartLine(
LogStartBatch(
description=batch_description,
index=self.batch_idx + 1,
total=len(self.batches),
batch_index=self.batch_idx + 1,
total_batches=len(self.batches),
node_info=self.node.node_info,
)
)
@@ -472,6 +477,25 @@ def merge_batch_results(self, result: RunResult, batch_results: List[RunResult])
if self.node.previous_batch_results is not None:
result.batch_results.successful += self.node.previous_batch_results.successful

def on_skip(self):
# If batch_idx is None, then we're dealing with skipping of the entire node
if self.batch_idx is None:
return super().on_skip()
else:
result = RunResult(
node=self.node,
status=RunStatus.Skipped,
timing=[],
thread_id=threading.current_thread().name,
execution_time=0.0,
message="SKIPPED",
adapter_response={},
failures=1,
batch_results=BatchResults(failed=[self.batches[self.batch_idx]]),
)
self.print_batch_result_line(result=result)
return result

def _build_succesful_run_batch_result(
self,
model: ModelNode,
@@ -602,13 +626,10 @@ def _has_relation(self, model) -> bool:
)
return relation is not None

def _should_run_in_parallel(
self,
relation_exists: bool,
) -> bool:
def should_run_in_parallel(self) -> bool:
if not self.adapter.supports(Capability.MicrobatchConcurrency):
run_in_parallel = False
elif not relation_exists:
elif not self.relation_exists:
# If the relation doesn't exist, we can't run in parallel
run_in_parallel = False
elif self.node.config.concurrent_batches is not None:
@@ -703,52 +724,122 @@ def handle_microbatch_model(
runner: MicrobatchModelRunner,
pool: ThreadPool,
) -> RunResult:
# Initial run computes batch metadata, unless model is skipped
# Initial run computes batch metadata
result = self.call_runner(runner)
batches, node, relation_exists = runner.batches, runner.node, runner.relation_exists

# Return early if model should be skipped, or there are no batches to execute
if result.status == RunStatus.Skipped:
return result
elif len(runner.batches) == 0:
return result

batch_results: List[RunResult] = []

# Execute batches serially until a relation exists, at which point future batches are run in parallel
relation_exists = runner.relation_exists
batch_idx = 0
while batch_idx < len(runner.batches):
batch_runner = MicrobatchModelRunner(
self.config, runner.adapter, deepcopy(runner.node), self.run_count, self.num_nodes
)
batch_runner.set_batch_idx(batch_idx)
batch_runner.set_relation_exists(relation_exists)
batch_runner.set_batches(runner.batches)

if runner._should_run_in_parallel(relation_exists):
fire_event(
MicrobatchExecutionDebug(
msg=f"{batch_runner.describe_batch} is being run concurrently"
)
)
self._submit(pool, [batch_runner], batch_results.append)
else:
fire_event(
MicrobatchExecutionDebug(
msg=f"{batch_runner.describe_batch} is being run sequentially"
)
)
batch_results.append(self.call_runner(batch_runner))
relation_exists = batch_runner.relation_exists

# Run first batch not in parallel
relation_exists = self._submit_batch(
node=node,
adapter=runner.adapter,
relation_exists=relation_exists,
batches=batches,
batch_idx=batch_idx,
batch_results=batch_results,
pool=pool,
force_sequential_run=True,
)
batch_idx += 1
skip_batches = batch_results[0].status != RunStatus.Success

# Run all batches except first and last batch, in parallel if possible
while batch_idx < len(runner.batches) - 1:
relation_exists = self._submit_batch(
node=node,
adapter=runner.adapter,
relation_exists=relation_exists,
batches=batches,
batch_idx=batch_idx,
batch_results=batch_results,
pool=pool,
skip=skip_batches,
)
batch_idx += 1

# Wait until all batches have completed
while len(batch_results) != len(runner.batches):
# Wait until all submitted batches have completed
while len(batch_results) != batch_idx:
pass
# Final batch runs once all others complete to ensure post_hook runs at the end
self._submit_batch(
node=node,
adapter=runner.adapter,
relation_exists=relation_exists,
batches=batches,
batch_idx=batch_idx,
batch_results=batch_results,
pool=pool,
force_sequential_run=True,
skip=skip_batches,
)

# Finalize run: merge results, track model run, and print final result line
runner.merge_batch_results(result, batch_results)
track_model_run(runner.node_index, runner.num_nodes, result, adapter=runner.adapter)
runner.print_result_line(result)

return result

def _submit_batch(
self,
node: ModelNode,
adapter: BaseAdapter,
relation_exists: bool,
batches: Dict[int, BatchType],
batch_idx: int,
batch_results: List[RunResult],
pool: ThreadPool,
force_sequential_run: bool = False,
skip: bool = False,
):
node_copy = deepcopy(node)
# Only run pre_hook(s) for first batch
if batch_idx != 0:
node_copy.config.pre_hook = []

# Only run post_hook(s) for last batch
if batch_idx != len(batches) - 1:
node_copy.config.post_hook = []

# TODO: We should be doing self.get_runner, however doing so
# currently causes the tracking of how many nodes there are to
# increment when we don't want it to
batch_runner = MicrobatchModelRunner(
self.config, adapter, node_copy, self.run_count, self.num_nodes
)
batch_runner.set_batch_idx(batch_idx)
batch_runner.set_relation_exists(relation_exists)
batch_runner.set_batches(batches)

if skip:
batch_runner.do_skip()

if not force_sequential_run and batch_runner.should_run_in_parallel():
fire_event(
MicrobatchExecutionDebug(
msg=f"{batch_runner.describe_batch} is being run concurrently"
)
)
self._submit(pool, [batch_runner], batch_results.append)
else:
fire_event(
MicrobatchExecutionDebug(
msg=f"{batch_runner.describe_batch} is being run sequentially"
)
)
batch_results.append(self.call_runner(batch_runner))
relation_exists = batch_runner.relation_exists

return relation_exists

def _hook_keyfunc(self, hook: HookNode) -> Tuple[str, Optional[int]]:
package_name = hook.package_name
if package_name == self.config.project_name:
(Diff truncated; remaining changed files not rendered.)
