Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add postflight to handle click exc and exit codes #7212

Merged
merged 3 commits into from
Mar 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230321-141804.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Add exception handling in postflight decorator to address exit codes
time: 2023-03-21T14:18:04.917329-05:00
custom:
Author: stu-k
Issue: "7010"
20 changes: 20 additions & 0 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def invoke(self, args: List[str]) -> Tuple[Optional[List], bool]:
"callbacks": self.callbacks,
}
return cli.invoke(dbt_ctx)
except requires.HandledExit as e:
return (e.result, e.success)
except requires.UnhandledExit as e:
raise e.exception
except click.exceptions.Exit as e:
# 0 exit code, expected for --version early exit
if str(e) == "0":
Expand Down Expand Up @@ -138,6 +142,7 @@ def cli(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def build(ctx, **kwargs):
"""Run all Seeds, Models, Snapshots, and tests in DAG order"""
task = BuildTask(
Expand All @@ -162,6 +167,7 @@ def build(ctx, **kwargs):
@requires.preflight
@requires.unset_profile
@requires.project
@requires.postflight
def clean(ctx, **kwargs):
"""Delete all folders in the clean-targets list (usually the dbt_packages and target directories.)"""
task = CleanTask(ctx.obj["flags"], ctx.obj["project"])
Expand Down Expand Up @@ -204,6 +210,7 @@ def docs(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest(write=False)
@requires.postflight
def docs_generate(ctx, **kwargs):
"""Generate the documentation website for your project"""
task = GenerateTask(
Expand Down Expand Up @@ -232,6 +239,7 @@ def docs_generate(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def docs_serve(ctx, **kwargs):
"""Serve the documentation website for your project"""
task = ServeTask(
Expand Down Expand Up @@ -275,6 +283,7 @@ def docs_serve(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def compile(ctx, **kwargs):
"""Generates executable SQL from source, model, test, and analysis files. Compiled SQL files are written to the
target/ directory."""
Expand All @@ -300,6 +309,7 @@ def compile(ctx, **kwargs):
@p.vars
@p.version_check
@requires.preflight
@requires.postflight
def debug(ctx, **kwargs):
"""Show some helpful information about dbt for debugging. Not to be confused with the --debug option which increases verbosity."""
task = DebugTask(
Expand All @@ -323,6 +333,7 @@ def debug(ctx, **kwargs):
@requires.preflight
@requires.unset_profile
@requires.project
@requires.postflight
def deps(ctx, **kwargs):
"""Pull the most recent version of the dependencies listed in packages.yml"""
task = DepsTask(ctx.obj["flags"], ctx.obj["project"])
Expand All @@ -343,6 +354,7 @@ def deps(ctx, **kwargs):
@p.target
@p.vars
@requires.preflight
@requires.postflight
def init(ctx, **kwargs):
"""Initialize a new dbt project."""
task = InitTask(ctx.obj["flags"], None)
Expand Down Expand Up @@ -375,6 +387,7 @@ def init(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def list(ctx, **kwargs):
"""List the resources in your project"""
task = ListTask(
Expand Down Expand Up @@ -412,6 +425,7 @@ def list(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest(write_perf_info=True)
@requires.postflight
def parse(ctx, **kwargs):
"""Parses the project and provides information on performance"""
# manifest generation and writing happens in @requires.manifest
Expand Down Expand Up @@ -445,6 +459,7 @@ def parse(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def run(ctx, **kwargs):
"""Compile SQL and execute against the current target database."""
task = RunTask(
Expand Down Expand Up @@ -473,6 +488,7 @@ def run(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def run_operation(ctx, **kwargs):
"""Run the named macro with any supplied arguments."""
task = RunOperationTask(
Expand Down Expand Up @@ -509,6 +525,7 @@ def run_operation(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def seed(ctx, **kwargs):
"""Load data from csv files into your data warehouse."""
task = SeedTask(
Expand Down Expand Up @@ -544,6 +561,7 @@ def seed(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def snapshot(ctx, **kwargs):
"""Execute snapshots defined in your project"""
task = SnapshotTask(
Expand Down Expand Up @@ -584,6 +602,7 @@ def source(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def freshness(ctx, **kwargs):
"""check the current freshness of the project's sources"""
task = FreshnessTask(
Expand Down Expand Up @@ -631,6 +650,7 @@ def freshness(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def test(ctx, **kwargs):
"""Runs tests on data in deployed models. Run this after `dbt run`"""
task = TestTask(
Expand Down
59 changes: 44 additions & 15 deletions core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,34 @@
from dbt.parser.manifest import ManifestLoader, write_manifest
from dbt.profiler import profiler
from dbt.tracking import active_user, initialize_from_flags, track_run
from dbt.utils import cast_dict_to_dict_of_strings
from dbt.utils import cast_dict_to_dict_of_strings, ExitCodes

from click import Context
from click.exceptions import ClickException
from functools import update_wrapper
import time
import traceback


class HandledExit(ClickException):
def __init__(self, result, success, exit_code: ExitCodes) -> None:
self.result = result
self.success = success
self.exit_code = exit_code

def show(self):
pass


class UnhandledExit(ClickException):
exit_code = ExitCodes.UnhandledError.value

def __init__(self, exception: Exception, message: str) -> None:
self.exception = exception
self.message = message

def format_message(self) -> str:
return self.message


def preflight(func):
Expand Down Expand Up @@ -61,11 +84,22 @@ def wrapper(*args, **kwargs):
# Adapter management
ctx.with_resource(adapter_management())

return func(*args, **kwargs)

return update_wrapper(wrapper, func)


def postflight(func):
def wrapper(*args, **kwargs):
ctx = args[0]
start_func = time.perf_counter()
success = False

try:
(results, success) = func(*args, **kwargs)

result, success = func(*args, **kwargs)
except Exception as e:
raise UnhandledExit(e, message=traceback.format_exc())
finally:
fire_event(
CommandCompleted(
command=ctx.command_path,
Expand All @@ -74,20 +108,15 @@ def wrapper(*args, **kwargs):
elapsed=time.perf_counter() - start_func,
)
)
# Bare except because we really do want to catch ALL exceptions,
# i.e. we want to fire this event in ALL cases.
except: # noqa
fire_event(
CommandCompleted(
command=ctx.command_path,
success=False,
completed_at=get_json_string_utcnow(),
elapsed=time.perf_counter() - start_func,
)

if not success:
raise HandledExit(
result=result,
success=success,
exit_code=ExitCodes.ModelError.value,
)
raise

return (results, success)
return (result, success)

return update_wrapper(wrapper, func)

Expand Down
38 changes: 38 additions & 0 deletions tests/functional/cli/test_cli_exit_codes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest

from dbt.cli.main import cli
from dbt.cli.requires import HandledExit


good_sql = """
select 1 as fun
"""

bad_sql = """
someting bad
"""


class CliRunnerBase:
def run_cli(self):
ctx = cli.make_context(cli.name, ["run"])
return cli.invoke(ctx)


class TestExitCodeZero(CliRunnerBase):
@pytest.fixture(scope="class")
def models(self):
return {"model_one.sql": good_sql}

def test_no_exc_thrown(self, project):
self.run_cli()


class TestExitCodeOne(CliRunnerBase):
@pytest.fixture(scope="class")
def models(self):
return {"model_one.sql": bad_sql}

def test_exc_thrown(self, project):
stu-k marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(HandledExit):
self.run_cli()
17 changes: 17 additions & 0 deletions tests/functional/cli/test_error_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest

from dbt.tests.util import run_dbt


model_one_sql = """
someting bad
"""


class TestHandledExit:
@pytest.fixture(scope="class")
def models(self):
return {"model_one.sql": model_one_sql}

def test_failed_run_does_not_throw(self, project):
run_dbt(["run"], expect_pass=False)