Skip to content

Commit

Permalink
Merge branch 'main' into qmalcolm--10874-make-event-time-start-end-mutually-required
Browse files Browse the repository at this point in the history
  • Loading branch information
QMalcolm committed Oct 21, 2024
2 parents 7b74ff0 + 7920b0e commit 74b7c09
Show file tree
Hide file tree
Showing 11 changed files with 172 additions and 40 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20241018-135810.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Exclude hook result from results in on-run-end context
time: 2024-10-18T13:58:10.396884-07:00
custom:
Author: ChenyuLInx
Issue: "7387"
16 changes: 12 additions & 4 deletions core/dbt/artifacts/schemas/results.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union

from dbt.contracts.graph.nodes import ResultNode
from dbt_common.dataclass_schema import StrEnum, dbtClassMixin
Expand All @@ -10,7 +10,13 @@

@dataclass
class TimingInfo(dbtClassMixin):
name: str
"""
Represents a step in the execution of a node.
`name` should be one of: compile, execute, or other
Do not call directly, use `collect_timing_info` instead.
"""

name: Literal["compile", "execute", "other"]
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None

Expand All @@ -21,7 +27,7 @@ def end(self):
self.completed_at = datetime.utcnow()

def to_msg_dict(self):
msg_dict = {"name": self.name}
msg_dict = {"name": str(self.name)}
if self.started_at:
msg_dict["started_at"] = datetime_to_json_string(self.started_at)
if self.completed_at:
Expand All @@ -31,7 +37,9 @@ def to_msg_dict(self):

# This is a context manager
class collect_timing_info:
def __init__(self, name: str, callback: Callable[[TimingInfo], None]) -> None:
def __init__(
self, name: Literal["compile", "execute", "other"], callback: Callable[[TimingInfo], None]
) -> None:
self.timing_info = TimingInfo(name=name)
self.callback = callback

Expand Down
23 changes: 16 additions & 7 deletions core/dbt/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RunningStatus,
RunStatus,
TimingInfo,
collect_timing_info,
)
from dbt.artifacts.schemas.run import RunResult
from dbt.cli.flags import Flags
Expand Down Expand Up @@ -633,7 +634,6 @@ def get_hooks_by_type(self, hook_type: RunHookType) -> List[HookNode]:
def safe_run_hooks(
self, adapter: BaseAdapter, hook_type: RunHookType, extra_context: Dict[str, Any]
) -> RunStatus:
started_at = datetime.utcnow()
ordered_hooks = self.get_hooks_by_type(hook_type)

if hook_type == RunHookType.End and ordered_hooks:
Expand All @@ -653,14 +653,20 @@ def safe_run_hooks(
hook.index = idx
hook_name = f"{hook.package_name}.{hook_type}.{hook.index - 1}"
execution_time = 0.0
timing = []
timing: List[TimingInfo] = []
failures = 1

if not failed:
with collect_timing_info("compile", timing.append):
sql = self.get_hook_sql(
adapter, hook, hook.index, num_hooks, extra_context
)

started_at = timing[0].started_at or datetime.utcnow()
hook.update_event_status(
started_at=started_at.isoformat(), node_status=RunningStatus.Started
)
sql = self.get_hook_sql(adapter, hook, hook.index, num_hooks, extra_context)

fire_event(
LogHookStartLine(
statement=hook_name,
Expand All @@ -670,11 +676,12 @@ def safe_run_hooks(
)
)

status, message = get_execution_status(sql, adapter)
finished_at = datetime.utcnow()
with collect_timing_info("execute", timing.append):
status, message = get_execution_status(sql, adapter)

finished_at = timing[1].completed_at or datetime.utcnow()
hook.update_event_status(finished_at=finished_at.isoformat())
execution_time = (finished_at - started_at).total_seconds()
timing = [TimingInfo(hook_name, started_at, finished_at)]
failures = 0 if status == RunStatus.Success else 1

if status == RunStatus.Success:
Expand Down Expand Up @@ -767,7 +774,9 @@ def after_run(self, adapter, results) -> None:

extras = {
"schemas": list({s for _, s in database_schema_set}),
"results": results,
"results": [
r for r in results if r.thread_id != "main" or r.status == RunStatus.Error
], # exclude hooks that didn't fail, to preserve backwards compatibility
"database_schemas": list(database_schema_set),
}
with adapter.connection_named("master"):
Expand Down
46 changes: 26 additions & 20 deletions core/dbt/task/run_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import threading
import traceback
from datetime import datetime
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List

import dbt_common.exceptions
from dbt.adapters.factory import get_adapter
from dbt.artifacts.schemas.results import RunStatus, TimingInfo
from dbt.artifacts.schemas.results import RunStatus, TimingInfo, collect_timing_info
from dbt.artifacts.schemas.run import RunResult, RunResultsArtifact
from dbt.contracts.files import FileHash
from dbt.contracts.graph.nodes import HookNode
Expand Down Expand Up @@ -51,25 +51,29 @@ def _run_unsafe(self, package_name, macro_name) -> "agate.Table":
return res

def run(self) -> RunResultsArtifact:
start = datetime.utcnow()
self.compile_manifest()
timing: List[TimingInfo] = []

success = True
with collect_timing_info("compile", timing.append):
self.compile_manifest()

start = timing[0].started_at

success = True
package_name, macro_name = self._get_macro_parts()

try:
self._run_unsafe(package_name, macro_name)
except dbt_common.exceptions.DbtBaseException as exc:
fire_event(RunningOperationCaughtError(exc=str(exc)))
fire_event(LogDebugStackTrace(exc_info=traceback.format_exc()))
success = False
except Exception as exc:
fire_event(RunningOperationUncaughtError(exc=str(exc)))
fire_event(LogDebugStackTrace(exc_info=traceback.format_exc()))
success = False
with collect_timing_info("execute", timing.append):
try:
self._run_unsafe(package_name, macro_name)
except dbt_common.exceptions.DbtBaseException as exc:
fire_event(RunningOperationCaughtError(exc=str(exc)))
fire_event(LogDebugStackTrace(exc_info=traceback.format_exc()))
success = False
except Exception as exc:
fire_event(RunningOperationUncaughtError(exc=str(exc)))
fire_event(LogDebugStackTrace(exc_info=traceback.format_exc()))
success = False

end = datetime.utcnow()
end = timing[1].completed_at

macro = (
self.manifest.find_macro_by_name(macro_name, self.config.project_name, package_name)
Expand All @@ -85,10 +89,12 @@ def run(self) -> RunResultsArtifact:
f"dbt could not find a macro with the name '{macro_name}' in any package"
)

execution_time = (end - start).total_seconds() if start and end else 0.0

run_result = RunResult(
adapter_response={},
status=RunStatus.Success if success else RunStatus.Error,
execution_time=(end - start).total_seconds(),
execution_time=execution_time,
failures=0 if success else 1,
message=None,
node=HookNode(
Expand All @@ -105,13 +111,13 @@ def run(self) -> RunResultsArtifact:
original_file_path="",
),
thread_id=threading.current_thread().name,
timing=[TimingInfo(name=macro_name, started_at=start, completed_at=end)],
timing=timing,
batch_results=None,
)

results = RunResultsArtifact.from_execution_results(
generated_at=end,
elapsed_time=(end - start).total_seconds(),
generated_at=end or datetime.utcnow(),
elapsed_time=execution_time,
args={
k: v
for k, v in self.args.__dict__.items()
Expand Down
1 change: 0 additions & 1 deletion core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,6 @@ def execute_with_hooks(self, selected_uids: AbstractSet[str]):
self.started_at = time.time()
try:
before_run_status = self.before_run(adapter, selected_uids)

if before_run_status == RunStatus.Success or (
not get_flags().skip_nodes_if_on_run_start_fails
):
Expand Down
6 changes: 5 additions & 1 deletion schemas/dbt/run-results/v6.json
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@
"title": "TimingInfo",
"properties": {
"name": {
"type": "string"
"enum": [
"compile",
"execute",
"other"
]
},
"started_at": {
"anyOf": [
Expand Down
6 changes: 5 additions & 1 deletion schemas/dbt/sources/v3.json
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,11 @@
"title": "TimingInfo",
"properties": {
"name": {
"type": "string"
"enum": [
"compile",
"execute",
"other"
]
},
"started_at": {
"anyOf": [
Expand Down
82 changes: 82 additions & 0 deletions tests/functional/adapter/hooks/test_on_run_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ def test_results(self, project, log_counts, my_model_run_status):
for result in results
if isinstance(result.node, HookNode)
] == [(id, str(status)) for id, status in expected_results if id.startswith("operation")]

for result in results:
if result.status == RunStatus.Skipped:
continue

timing_keys = [timing.name for timing in result.timing]
assert timing_keys == ["compile", "execute"]

assert log_counts in log_output
assert "4 project hooks, 1 view model" in log_output

Expand Down Expand Up @@ -160,3 +168,77 @@ def test_results(self, project):

run_results = get_artifact(project.project_root, "target", "run_results.json")
assert run_results["results"] == []


class Test__HookContext__HookSuccess:
    """Verify the `results` variable in the on-run-end context when hooks succeed.

    Successful hook results should be excluded from `results` in the
    on-run-end context, while still appearing in the run's result list.
    """

    @pytest.fixture(scope="class")
    def project_config_update(self):
        # Two successful on-run-start hooks; the on-run-end hook logs how many
        # results it sees in its context and each result's thread_id.
        # NOTE(review): the two adjacent string literals in "on-run-end" are
        # implicitly concatenated into a single hook — confirm a comma was not
        # intended (a single combined hook still exercises both statements).
        return {
            "on-run-start": [
                "select 1 as id",  # success
                "select 1 as id",  # success
            ],
            "on-run-end": [
                '{{ log("Num Results in context: " ~ results|length)}}'
                "{{ output_thread_ids(results) }}",
            ],
        }

    @pytest.fixture(scope="class")
    def macros(self):
        # Helper macro used by the on-run-end hook to log each result's thread_id.
        return {
            "log.sql": """
{% macro output_thread_ids(results) %}
{% for result in results %}
{{ log("Thread ID: " ~ result.thread_id) }}
{% endfor %}
{% endmacro %}
"""
        }

    @pytest.fixture(scope="class")
    def models(self):
        # One trivial model so the run produces exactly one model result.
        return {"my_model.sql": "select 1"}

    def test_results_in_context_success(self, project):
        results, log_output = run_dbt_and_capture(["--debug", "run"])
        # Some result thread ids were logged by the macro, but none from the
        # "main" thread — successful hook results are excluded from context.
        assert "Thread ID: " in log_output
        assert "Thread ID: main" not in log_output
        assert results[0].thread_id == "main"  # hook still exists in run results
        assert "Num Results in context: 1" in log_output  # only model given hook was successful


class Test__HookContext__HookFail:
    """Verify the on-run-end context `results` includes a failed hook result.

    Unlike successful hooks, a failed on-run-start hook's result should be
    passed through to `results` in the on-run-end context.
    """

    @pytest.fixture(scope="class")
    def project_config_update(self):
        # One on-run-start hook that fails (column `a` does not exist); the
        # on-run-end hook logs the context's result count and thread ids.
        # NOTE(review): the two adjacent string literals in "on-run-end" are
        # implicitly concatenated into a single hook — confirm a comma was not
        # intended (a single combined hook still exercises both statements).
        return {
            "on-run-start": [
                "select a as id",  # fail
            ],
            "on-run-end": [
                '{{ log("Num Results in context: " ~ results|length)}}'
                "{{ output_thread_ids(results) }}",
            ],
        }

    @pytest.fixture(scope="class")
    def macros(self):
        # Helper macro used by the on-run-end hook to log each result's thread_id.
        return {
            "log.sql": """
{% macro output_thread_ids(results) %}
{% for result in results %}
{{ log("Thread ID: " ~ result.thread_id) }}
{% endfor %}
{% endmacro %}
"""
        }

    @pytest.fixture(scope="class")
    def models(self):
        # One trivial model; it runs even though the start hook failed.
        return {"my_model.sql": "select 1"}

    def test_results_in_context_hook_fail(self, project):
        results, log_output = run_dbt_and_capture(["--debug", "run"], expect_pass=False)
        # The failed hook ran on the "main" thread and IS visible in the
        # on-run-end context, alongside the model result.
        assert "Thread ID: main" in log_output
        assert results[0].thread_id == "main"
        assert "Num Results in context: 2" in log_output  # failed hook and model
7 changes: 2 additions & 5 deletions tests/functional/microbatch/test_microbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,8 @@ def test_use_custom_microbatch_strategy_env_var_true_invalid_incremental_strateg
with mock.patch.object(
type(project.adapter), "valid_incremental_strategies", lambda _: []
):
# Initial run
with patch_microbatch_end_time("2020-01-03 13:57:00"):
run_dbt(["run"])

# Incremental run fails
# Run of microbatch model while adapter doesn't have a "valid"
# microbatch strategy causes an error to be raised
with patch_microbatch_end_time("2020-01-03 13:57:00"):
_, logs = run_dbt_and_capture(["run"], expect_pass=False)
assert "'microbatch' is not valid" in logs
Expand Down
17 changes: 17 additions & 0 deletions tests/functional/run_operations/test_run_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
import yaml

from dbt.artifacts.schemas.results import RunStatus
from dbt.tests.util import (
check_table_does_exist,
mkdir,
Expand Down Expand Up @@ -135,9 +136,25 @@ def test_run_operation_local_macro(self, project):
run_dbt(["deps"])

results, log_output = run_dbt_and_capture(["run-operation", "something_cool"])

for result in results:
if result.status == RunStatus.Skipped:
continue

timing_keys = [timing.name for timing in result.timing]
assert timing_keys == ["compile", "execute"]

assert "something cool" in log_output

results, log_output = run_dbt_and_capture(["run-operation", "pkg.something_cool"])

for result in results:
if result.status == RunStatus.Skipped:
continue

timing_keys = [timing.name for timing in result.timing]
assert timing_keys == ["compile", "execute"]

assert "something cool" in log_output

rm_dir("pkg")
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ def test_all_serializable(self):


def test_date_serialization():
ti = TimingInfo("test")
ti = TimingInfo("compile")
ti.begin()
ti.end()
ti_dict = ti.to_dict()
Expand Down

0 comments on commit 74b7c09

Please sign in to comment.