Skip to content

Commit

Permalink
Backport 9328 to 1.7.latest (#9391)
Browse files Browse the repository at this point in the history
* Fix full-refresh and vars for retry (#9328)

Co-authored-by: Peter Allen Webb <[email protected]>
(cherry picked from commit 1e4286a)

* pr feedback

* Update requires.py
  • Loading branch information
ChenyuLInx authored Jan 17, 2024
1 parent d338b3e commit 6e33183
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 67 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20231213-220449.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Preserve the value of vars and the --full-refresh flags when using retry.
time: 2023-12-13T22:04:49.228294-05:00
custom:
Author: peterallenwebb, ChenyuLInx
Issue: "9112"
8 changes: 4 additions & 4 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
FLAGS_DEFAULTS = {
"INDIRECT_SELECTION": "eager",
"TARGET_PATH": None,
"WARN_ERROR": None,
# Cli args without user_config or env var option.
"FULL_REFRESH": False,
"STRICT_MODE": False,
Expand Down Expand Up @@ -78,7 +79,6 @@ class Flags:
def __init__(
self, ctx: Optional[Context] = None, user_config: Optional[UserConfig] = None
) -> None:

# Set the default flags.
for key, value in FLAGS_DEFAULTS.items():
object.__setattr__(self, key, value)
Expand Down Expand Up @@ -120,7 +120,6 @@ def _assign_params(
# respected over DBT_PRINT or --print.
new_name: Union[str, None] = None
if param_name in DEPRECATED_PARAMS:

# Deprecated env vars can only be set via env var.
# We use the deprecated option in click to serialize the value
# from the env var string.
Expand Down Expand Up @@ -315,7 +314,6 @@ def command_params(command: CliCommand, args_dict: Dict[str, Any]) -> CommandPar
default_args = set([x.lower() for x in FLAGS_DEFAULTS.keys()])

res = command.to_list()

for k, v in args_dict.items():
k = k.lower()
# if a "which" value exists in the args dict, it should match the command provided
Expand All @@ -327,7 +325,9 @@ def command_params(command: CliCommand, args_dict: Dict[str, Any]) -> CommandPar
continue

# param was assigned from defaults and should not be included
if k not in (cmd_args | prnt_args) - default_args:
if k not in (cmd_args | prnt_args) or (
k in default_args and v == FLAGS_DEFAULTS[k.upper()]
):
continue

# if the param is in parent args, it should come before the arg name
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,12 +638,12 @@ def run(ctx, **kwargs):
@p.target
@p.state
@p.threads
@p.full_refresh
@requires.postflight
@requires.preflight
@requires.profile
@requires.project
@requires.runtime_config
@requires.manifest
def retry(ctx, **kwargs):
"""Retry the nodes that failed in the previous run."""
task = RetryTask(
Expand Down
22 changes: 7 additions & 15 deletions core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
from dbt.events.helpers import get_json_string_utcnow
from dbt.events.types import MainEncounteredError, MainStackTrace
from dbt.exceptions import Exception as DbtException, DbtProjectError, FailFastError
from dbt.parser.manifest import ManifestLoader, write_manifest
from dbt.parser.manifest import parse_manifest
from dbt.profiler import profiler
from dbt.tracking import active_user, initialize_from_flags, track_run
from dbt.utils import cast_dict_to_dict_of_strings
from dbt.plugins import set_up_plugin_manager, get_plugin_manager
from dbt.plugins import set_up_plugin_manager


from click import Context
from functools import update_wrapper
Expand Down Expand Up @@ -264,23 +265,14 @@ def wrapper(*args, **kwargs):
raise DbtProjectError("profile, project, and runtime_config required for manifest")

runtime_config = ctx.obj["runtime_config"]
register_adapter(runtime_config)

# a manifest has already been set on the context, so don't overwrite it
if ctx.obj.get("manifest") is None:
manifest = ManifestLoader.get_full_manifest(
runtime_config,
write_perf_info=write_perf_info,
ctx.obj["manifest"] = parse_manifest(
runtime_config, write_perf_info, write, ctx.obj["flags"].write_json
)

ctx.obj["manifest"] = manifest
if write and ctx.obj["flags"].write_json:
write_manifest(manifest, runtime_config.project_target_path)
pm = get_plugin_manager(runtime_config.project_name)
plugin_artifacts = pm.get_manifest_artifacts(manifest)
for path, plugin_artifact in plugin_artifacts.items():
plugin_artifact.write(path)

else:
register_adapter(runtime_config)
return func(*args, **kwargs)

return update_wrapper(wrapper, func)
Expand Down
17 changes: 11 additions & 6 deletions core/dbt/contracts/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@
from dbt.exceptions import IncompatibleSchemaError


def load_result_state(results_path) -> Optional[RunResultsArtifact]:
if results_path.exists() and results_path.is_file():
try:
return RunResultsArtifact.read_and_check_versions(str(results_path))
except IncompatibleSchemaError as exc:
exc.add_filename(str(results_path))
raise
return None


class PreviousState:
def __init__(self, state_path: Path, target_path: Path, project_root: Path) -> None:
self.state_path: Path = state_path
Expand All @@ -32,12 +42,7 @@ def __init__(self, state_path: Path, target_path: Path, project_root: Path) -> N
raise

results_path = self.project_root / self.state_path / "run_results.json"
if results_path.exists() and results_path.is_file():
try:
self.results = RunResultsArtifact.read_and_check_versions(str(results_path))
except IncompatibleSchemaError as exc:
exc.add_filename(str(results_path))
raise
self.results = load_result_state(results_path)

sources_path = self.project_root / self.state_path / "sources.json"
if sources_path.exists() and sources_path.is_file():
Expand Down
21 changes: 17 additions & 4 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
get_adapter,
get_relation_class_by_name,
get_adapter_package_names,
register_adapter,
)
from dbt.constants import (
MANIFEST_FILE_NAME,
Expand Down Expand Up @@ -278,7 +279,6 @@ def get_full_manifest(
reset: bool = False,
write_perf_info=False,
) -> Manifest:

adapter = get_adapter(config) # type: ignore
# reset is set in a TaskManager load_manifest call, since
# the config and adapter may be persistent.
Expand Down Expand Up @@ -590,7 +590,6 @@ def check_for_model_deprecations(self):
node.depends_on
for resolved_ref in resolved_model_refs:
if resolved_ref.deprecation_date:

if resolved_ref.deprecation_date < datetime.datetime.now().astimezone():
event_cls = DeprecatedReference
else:
Expand Down Expand Up @@ -1733,7 +1732,6 @@ def _process_sources_for_metric(manifest: Manifest, current_project: str, metric


def _process_sources_for_node(manifest: Manifest, current_project: str, node: ManifestNode):

if isinstance(node, SeedNode):
return

Expand Down Expand Up @@ -1775,7 +1773,6 @@ def process_macro(config: RuntimeConfig, manifest: Manifest, macro: Macro) -> No
# This is called in task.rpc.sql_commands when a "dynamic" node is
# created in the manifest, in 'add_refs'
def process_node(config: RuntimeConfig, manifest: Manifest, node: ManifestNode):

_process_sources_for_node(manifest, config.project_name, node)
_process_refs(manifest, config.project_name, node, config.dependencies)
ctx = generate_runtime_docs_context(config, node, manifest, config.project_name)
Expand All @@ -1793,3 +1790,19 @@ def write_manifest(manifest: Manifest, target_path: str):
manifest.write(path)

write_semantic_manifest(manifest=manifest, target_path=target_path)


def parse_manifest(runtime_config, write_perf_info, write, write_json):
register_adapter(runtime_config)
manifest = ManifestLoader.get_full_manifest(
runtime_config,
write_perf_info=write_perf_info,
)

if write and write_json:
write_manifest(manifest, runtime_config.project_target_path)
pm = plugins.get_plugin_manager(runtime_config.project_name)
plugin_artifacts = pm.get_manifest_artifacts(manifest)
for path, plugin_artifact in plugin_artifacts.items():
plugin_artifact.write(path)
return manifest
84 changes: 48 additions & 36 deletions core/dbt/task/retry.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from pathlib import Path
from click import get_current_context
from click.core import ParameterSource

from dbt.cli.flags import Flags
from dbt.flags import set_flags, get_flags
from dbt.cli.types import Command as CliCommand
from dbt.config import RuntimeConfig
from dbt.contracts.results import NodeStatus
from dbt.contracts.state import PreviousState
from dbt.contracts.state import load_result_state
from dbt.exceptions import DbtRuntimeError
from dbt.graph import GraphQueue
from dbt.task.base import ConfiguredTask
Expand All @@ -17,9 +20,10 @@
from dbt.task.seed import SeedTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask
from dbt.parser.manifest import parse_manifest

RETRYABLE_STATUSES = {NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped, NodeStatus.RuntimeErr}
OVERRIDE_PARENT_FLAGS = {
IGNORE_PARENT_FLAGS = {
"log_path",
"output_path",
"profiles_dir",
Expand All @@ -28,8 +32,11 @@
"defer_state",
"deprecated_state",
"target_path",
"warn_error",
}

ALLOW_CLI_OVERRIDE_FLAGS = {"vars"}

TASK_DICT = {
"build": BuildTask,
"compile": CompileTask,
Expand Down Expand Up @@ -57,59 +64,64 @@

class RetryTask(ConfiguredTask):
def __init__(self, args, config, manifest) -> None:
super().__init__(args, config, manifest)

state_path = self.args.state or self.config.target_path

if self.args.warn_error:
RETRYABLE_STATUSES.add(NodeStatus.Warn)

self.previous_state = PreviousState(
state_path=Path(state_path),
target_path=Path(self.config.target_path),
project_root=Path(self.config.project_root),
# load previous run results
state_path = args.state or config.target_path
self.previous_results = load_result_state(
Path(config.project_root) / Path(state_path) / "run_results.json"
)

if not self.previous_state.results:
if not self.previous_results:
raise DbtRuntimeError(
f"Could not find previous run in '{state_path}' target directory"
)

self.previous_args = self.previous_state.results.args
self.previous_args = self.previous_results.args
self.previous_command_name = self.previous_args.get("which")
self.task_class = TASK_DICT.get(self.previous_command_name) # type: ignore

def run(self):
unique_ids = set(
[
result.unique_id
for result in self.previous_state.results.results
if result.status in RETRYABLE_STATUSES
]
)

cli_command = CMD_DICT.get(self.previous_command_name)
# Reslove flags and config
if args.warn_error:
RETRYABLE_STATUSES.add(NodeStatus.Warn)

cli_command = CMD_DICT.get(self.previous_command_name) # type: ignore
# Remove these args when their default values are present, otherwise they'll raise an exception
args_to_remove = {
"show": lambda x: True,
"resource_types": lambda x: x == [],
"warn_error_options": lambda x: x == {"exclude": [], "include": []},
}

for k, v in args_to_remove.items():
if k in self.previous_args and v(self.previous_args[k]):
del self.previous_args[k]

previous_args = {
k: v for k, v in self.previous_args.items() if k not in OVERRIDE_PARENT_FLAGS
k: v for k, v in self.previous_args.items() if k not in IGNORE_PARENT_FLAGS
}
click_context = get_current_context()
current_args = {
k: v
for k, v in args.__dict__.items()
if k in IGNORE_PARENT_FLAGS
or (
click_context.get_parameter_source(k) == ParameterSource.COMMANDLINE
and k in ALLOW_CLI_OVERRIDE_FLAGS
)
}
current_args = {k: v for k, v in self.args.__dict__.items() if k in OVERRIDE_PARENT_FLAGS}
combined_args = {**previous_args, **current_args}

retry_flags = Flags.from_dict(cli_command, combined_args)
retry_flags = Flags.from_dict(cli_command, combined_args) # type: ignore
set_flags(retry_flags)
retry_config = RuntimeConfig.from_args(args=retry_flags)

# Parse manifest using resolved config/flags
manifest = parse_manifest(retry_config, False, True, retry_flags.write_json) # type: ignore
super().__init__(args, retry_config, manifest)
self.task_class = TASK_DICT.get(self.previous_command_name) # type: ignore

def run(self):
unique_ids = set(
[
result.unique_id
for result in self.previous_results.results
if result.status in RETRYABLE_STATUSES
]
)

class TaskWrapper(self.task_class):
def get_graph_queue(self):
new_graph = self.graph.get_subset_graph(unique_ids)
Expand All @@ -120,8 +132,8 @@ def get_graph_queue(self):
)

task = TaskWrapper(
retry_flags,
retry_config,
get_flags(),
self.config,
self.manifest,
)

Expand Down
38 changes: 37 additions & 1 deletion tests/functional/retry/test_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,10 @@ def test_previous_run(self, project):
write_file(models__sample_model, "models", "sample_model.sql")

def test_warn_error(self, project):
# Regular build
# Our test command should succeed when run normally...
results = run_dbt(["build", "--select", "second_model"])

# ...but it should fail when run with warn-error, due to a warning...
results = run_dbt(["--warn-error", "build", "--select", "second_model"], expect_pass=False)

expected_statuses = {
Expand Down Expand Up @@ -291,3 +294,36 @@ def test_retry(self, project):
run_dbt(["run", "--project-dir", "proj_location_1"], expect_pass=False)
move(proj_location_1, proj_location_2)
run_dbt(["retry", "--project-dir", "proj_location_2"], expect_pass=False)


class TestRetryVars:
@pytest.fixture(scope="class")
def models(self):
return {
"sample_model.sql": "select {{ var('myvar_a', '1') + var('myvar_b', '2') }} as mycol",
}

def test_retry(self, project):
# pass because default vars works
run_dbt(["run"])
run_dbt(["run", "--vars", '{"myvar_a": "12", "myvar_b": "3 4"}'], expect_pass=False)
# fail because vars are invalid, this shows that the last passed vars are being used
# instead of using the default vars
run_dbt(["retry"], expect_pass=False)
results = run_dbt(["retry", "--vars", '{"myvar_a": "12", "myvar_b": "34"}'])
assert len(results) == 1


class TestRetryFullRefresh:
@pytest.fixture(scope="class")
def models(self):
return {
"sample_model.sql": "{% if flags.FULL_REFRESH %} this is invalid sql {% else %} select 1 as mycol {% endif %}",
}

def test_retry(self, project):
# This run should fail with invalid sql...
run_dbt(["run", "--full-refresh"], expect_pass=False)
# ...and so should this one, since the effect of the full-refresh parameter should persist.
results = run_dbt(["retry"], expect_pass=False)
assert len(results) == 1

0 comments on commit 6e33183

Please sign in to comment.