Skip to content

Commit

Permalink
Add postflight to handle click exc and exit codes (#7212)
Browse files Browse the repository at this point in the history
  • Loading branch information
stu-k authored Mar 27, 2023
1 parent bf5ed39 commit 6992151
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 15 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230321-141804.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Add exception handling in postflight decorator to address exit codes
time: 2023-03-21T14:18:04.917329-05:00
custom:
Author: stu-k
Issue: "7010"
20 changes: 20 additions & 0 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def invoke(self, args: List[str]) -> Tuple[Optional[List], bool]:
"callbacks": self.callbacks,
}
return cli.invoke(dbt_ctx)
except requires.HandledExit as e:
return (e.result, e.success)
except requires.UnhandledExit as e:
raise e.exception
except click.exceptions.Exit as e:
# 0 exit code, expected for --version early exit
if str(e) == "0":
Expand Down Expand Up @@ -138,6 +142,7 @@ def cli(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def build(ctx, **kwargs):
"""Run all Seeds, Models, Snapshots, and tests in DAG order"""
task = BuildTask(
Expand All @@ -162,6 +167,7 @@ def build(ctx, **kwargs):
@requires.preflight
@requires.unset_profile
@requires.project
@requires.postflight
def clean(ctx, **kwargs):
"""Delete all folders in the clean-targets list (usually the dbt_packages and target directories.)"""
task = CleanTask(ctx.obj["flags"], ctx.obj["project"])
Expand Down Expand Up @@ -204,6 +210,7 @@ def docs(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest(write=False)
@requires.postflight
def docs_generate(ctx, **kwargs):
"""Generate the documentation website for your project"""
task = GenerateTask(
Expand Down Expand Up @@ -232,6 +239,7 @@ def docs_generate(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def docs_serve(ctx, **kwargs):
"""Serve the documentation website for your project"""
task = ServeTask(
Expand Down Expand Up @@ -275,6 +283,7 @@ def docs_serve(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def compile(ctx, **kwargs):
"""Generates executable SQL from source, model, test, and analysis files. Compiled SQL files are written to the
target/ directory."""
Expand All @@ -300,6 +309,7 @@ def compile(ctx, **kwargs):
@p.vars
@p.version_check
@requires.preflight
@requires.postflight
def debug(ctx, **kwargs):
"""Show some helpful information about dbt for debugging. Not to be confused with the --debug option which increases verbosity."""
task = DebugTask(
Expand All @@ -323,6 +333,7 @@ def debug(ctx, **kwargs):
@requires.preflight
@requires.unset_profile
@requires.project
@requires.postflight
def deps(ctx, **kwargs):
"""Pull the most recent version of the dependencies listed in packages.yml"""
task = DepsTask(ctx.obj["flags"], ctx.obj["project"])
Expand All @@ -343,6 +354,7 @@ def deps(ctx, **kwargs):
@p.target
@p.vars
@requires.preflight
@requires.postflight
def init(ctx, **kwargs):
"""Initialize a new dbt project."""
task = InitTask(ctx.obj["flags"], None)
Expand Down Expand Up @@ -375,6 +387,7 @@ def init(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def list(ctx, **kwargs):
"""List the resources in your project"""
task = ListTask(
Expand Down Expand Up @@ -412,6 +425,7 @@ def list(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest(write_perf_info=True)
@requires.postflight
def parse(ctx, **kwargs):
"""Parses the project and provides information on performance"""
# manifest generation and writing happens in @requires.manifest
Expand Down Expand Up @@ -445,6 +459,7 @@ def parse(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def run(ctx, **kwargs):
"""Compile SQL and execute against the current target database."""
task = RunTask(
Expand Down Expand Up @@ -473,6 +488,7 @@ def run(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def run_operation(ctx, **kwargs):
"""Run the named macro with any supplied arguments."""
task = RunOperationTask(
Expand Down Expand Up @@ -509,6 +525,7 @@ def run_operation(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def seed(ctx, **kwargs):
"""Load data from csv files into your data warehouse."""
task = SeedTask(
Expand Down Expand Up @@ -544,6 +561,7 @@ def seed(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def snapshot(ctx, **kwargs):
"""Execute snapshots defined in your project"""
task = SnapshotTask(
Expand Down Expand Up @@ -584,6 +602,7 @@ def source(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def freshness(ctx, **kwargs):
"""check the current freshness of the project's sources"""
task = FreshnessTask(
Expand Down Expand Up @@ -631,6 +650,7 @@ def freshness(ctx, **kwargs):
@requires.project
@requires.runtime_config
@requires.manifest
@requires.postflight
def test(ctx, **kwargs):
"""Runs tests on data in deployed models. Run this after `dbt run`"""
task = TestTask(
Expand Down
59 changes: 44 additions & 15 deletions core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,34 @@
from dbt.parser.manifest import ManifestLoader, write_manifest
from dbt.profiler import profiler
from dbt.tracking import active_user, initialize_from_flags, track_run
from dbt.utils import cast_dict_to_dict_of_strings
from dbt.utils import cast_dict_to_dict_of_strings, ExitCodes

from click import Context
from click.exceptions import ClickException
from functools import update_wrapper
import time
import traceback


class HandledExit(ClickException):
def __init__(self, result, success, exit_code: ExitCodes) -> None:
self.result = result
self.success = success
self.exit_code = exit_code

def show(self):
pass


class UnhandledExit(ClickException):
exit_code = ExitCodes.UnhandledError.value

def __init__(self, exception: Exception, message: str) -> None:
self.exception = exception
self.message = message

def format_message(self) -> str:
return self.message


def preflight(func):
Expand Down Expand Up @@ -61,11 +84,22 @@ def wrapper(*args, **kwargs):
# Adapter management
ctx.with_resource(adapter_management())

return func(*args, **kwargs)

return update_wrapper(wrapper, func)


def postflight(func):
def wrapper(*args, **kwargs):
ctx = args[0]
start_func = time.perf_counter()
success = False

try:
(results, success) = func(*args, **kwargs)

result, success = func(*args, **kwargs)
except Exception as e:
raise UnhandledExit(e, message=traceback.format_exc())
finally:
fire_event(
CommandCompleted(
command=ctx.command_path,
Expand All @@ -74,20 +108,15 @@ def wrapper(*args, **kwargs):
elapsed=time.perf_counter() - start_func,
)
)
# Bare except because we really do want to catch ALL exceptions,
# i.e. we want to fire this event in ALL cases.
except: # noqa
fire_event(
CommandCompleted(
command=ctx.command_path,
success=False,
completed_at=get_json_string_utcnow(),
elapsed=time.perf_counter() - start_func,
)

if not success:
raise HandledExit(
result=result,
success=success,
exit_code=ExitCodes.ModelError.value,
)
raise

return (results, success)
return (result, success)

return update_wrapper(wrapper, func)

Expand Down
38 changes: 38 additions & 0 deletions tests/functional/cli/test_cli_exit_codes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest

from dbt.cli.main import cli
from dbt.cli.requires import HandledExit


good_sql = """
select 1 as fun
"""

bad_sql = """
someting bad
"""


class CliRunnerBase:
def run_cli(self):
ctx = cli.make_context(cli.name, ["run"])
return cli.invoke(ctx)


class TestExitCodeZero(CliRunnerBase):
@pytest.fixture(scope="class")
def models(self):
return {"model_one.sql": good_sql}

def test_no_exc_thrown(self, project):
self.run_cli()


class TestExitCodeOne(CliRunnerBase):
@pytest.fixture(scope="class")
def models(self):
return {"model_one.sql": bad_sql}

def test_exc_thrown(self, project):
with pytest.raises(HandledExit):
self.run_cli()
17 changes: 17 additions & 0 deletions tests/functional/cli/test_error_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest

from dbt.tests.util import run_dbt


model_one_sql = """
someting bad
"""


class TestHandledExit:
@pytest.fixture(scope="class")
def models(self):
return {"model_one.sql": model_one_sql}

def test_failed_run_does_not_throw(self, project):
run_dbt(["run"], expect_pass=False)

0 comments on commit 6992151

Please sign in to comment.