diff --git a/.changes/unreleased/Fixes-20230321-141804.yaml b/.changes/unreleased/Fixes-20230321-141804.yaml new file mode 100644 index 00000000000..2f3f8e0472d --- /dev/null +++ b/.changes/unreleased/Fixes-20230321-141804.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Add exception handling in postflight decorator to address exit codes +time: 2023-03-21T14:18:04.917329-05:00 +custom: + Author: stu-k + Issue: "7010" diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index d65332ca6d0..19dc8d327cb 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -56,6 +56,10 @@ def invoke(self, args: List[str]) -> Tuple[Optional[List], bool]: "callbacks": self.callbacks, } return cli.invoke(dbt_ctx) + except requires.HandledExit as e: + return (e.result, e.success) + except requires.UnhandledExit as e: + raise e.exception except click.exceptions.Exit as e: # 0 exit code, expected for --version early exit if str(e) == "0": @@ -138,6 +142,7 @@ def cli(ctx, **kwargs): @requires.project @requires.runtime_config @requires.manifest +@requires.postflight def build(ctx, **kwargs): """Run all Seeds, Models, Snapshots, and tests in DAG order""" task = BuildTask( @@ -162,6 +167,7 @@ def build(ctx, **kwargs): @requires.preflight @requires.unset_profile @requires.project +@requires.postflight def clean(ctx, **kwargs): """Delete all folders in the clean-targets list (usually the dbt_packages and target directories.)""" task = CleanTask(ctx.obj["flags"], ctx.obj["project"]) @@ -204,6 +210,7 @@ def docs(ctx, **kwargs): @requires.project @requires.runtime_config @requires.manifest(write=False) +@requires.postflight def docs_generate(ctx, **kwargs): """Generate the documentation website for your project""" task = GenerateTask( @@ -232,6 +239,7 @@ def docs_generate(ctx, **kwargs): @requires.project @requires.runtime_config @requires.manifest +@requires.postflight def docs_serve(ctx, **kwargs): """Serve the documentation website for your project""" task = ServeTask( @@ -275,6 +283,7 @@ def docs_serve(ctx, **kwargs): @requires.project @requires.runtime_config @requires.manifest +@requires.postflight def compile(ctx, **kwargs): """Generates executable SQL from source, model, test, and analysis files. Compiled SQL files are written to the target/ directory.""" @@ -300,6 +309,7 @@ def compile(ctx, **kwargs): @p.vars @p.version_check @requires.preflight +@requires.postflight def debug(ctx, **kwargs): """Show some helpful information about dbt for debugging. Not to be confused with the --debug option which increases verbosity.""" task = DebugTask( @@ -323,6 +333,7 @@ def debug(ctx, **kwargs): @requires.preflight @requires.unset_profile @requires.project +@requires.postflight def deps(ctx, **kwargs): """Pull the most recent version of the dependencies listed in packages.yml""" task = DepsTask(ctx.obj["flags"], ctx.obj["project"]) @@ -343,6 +354,7 @@ def deps(ctx, **kwargs): @p.target @p.vars @requires.preflight +@requires.postflight def init(ctx, **kwargs): """Initialize a new dbt project.""" task = InitTask(ctx.obj["flags"], None) @@ -375,6 +387,7 @@ def init(ctx, **kwargs): @requires.project @requires.runtime_config @requires.manifest +@requires.postflight def list(ctx, **kwargs): """List the resources in your project""" task = ListTask( @@ -412,6 +425,7 @@ def list(ctx, **kwargs): @requires.project @requires.runtime_config @requires.manifest(write_perf_info=True) +@requires.postflight def parse(ctx, **kwargs): """Parses the project and provides information on performance""" # manifest generation and writing happens in @requires.manifest @@ -445,6 +459,7 @@ def parse(ctx, **kwargs): @requires.project @requires.runtime_config @requires.manifest +@requires.postflight def run(ctx, **kwargs): """Compile SQL and execute against the current target database.""" task = RunTask( @@ -473,6 +488,7 @@ def run(ctx, **kwargs): @requires.project @requires.runtime_config @requires.manifest +@requires.postflight def run_operation(ctx, **kwargs): """Run the named macro with any supplied arguments.""" task = RunOperationTask( @@ -509,6 +525,7 @@ def run_operation(ctx, **kwargs): @requires.project @requires.runtime_config @requires.manifest +@requires.postflight def seed(ctx, **kwargs): """Load data from csv files into your data warehouse.""" task = SeedTask( @@ -544,6 +561,7 @@ def seed(ctx, **kwargs): @requires.project @requires.runtime_config @requires.manifest +@requires.postflight def snapshot(ctx, **kwargs): """Execute snapshots defined in your project""" task = SnapshotTask( @@ -584,6 +602,7 @@ def source(ctx, **kwargs): @requires.project @requires.runtime_config @requires.manifest +@requires.postflight def freshness(ctx, **kwargs): """check the current freshness of the project's sources""" task = FreshnessTask( @@ -631,6 +650,7 @@ def freshness(ctx, **kwargs): @requires.project @requires.runtime_config @requires.manifest +@requires.postflight def test(ctx, **kwargs): """Runs tests on data in deployed models. Run this after `dbt run`""" task = TestTask( diff --git a/core/dbt/cli/requires.py b/core/dbt/cli/requires.py index fb68d753c26..33e4f919cbc 100644 --- a/core/dbt/cli/requires.py +++ b/core/dbt/cli/requires.py @@ -16,11 +16,34 @@ from dbt.parser.manifest import ManifestLoader, write_manifest from dbt.profiler import profiler from dbt.tracking import active_user, initialize_from_flags, track_run -from dbt.utils import cast_dict_to_dict_of_strings +from dbt.utils import cast_dict_to_dict_of_strings, ExitCodes from click import Context +from click.exceptions import ClickException from functools import update_wrapper import time +import traceback + + +class HandledExit(ClickException): + def __init__(self, result, success, exit_code: ExitCodes) -> None: + self.result = result + self.success = success + self.exit_code = exit_code + + def show(self): + pass + + +class UnhandledExit(ClickException): + exit_code = ExitCodes.UnhandledError.value + + def __init__(self, exception: Exception, message: str) -> None: + self.exception = exception + self.message = message + + def format_message(self) -> str: + return self.message def preflight(func): @@ -61,11 +84,22 @@ def wrapper(*args, **kwargs): # Adapter management ctx.with_resource(adapter_management()) + return func(*args, **kwargs) + + return update_wrapper(wrapper, func) + + +def postflight(func): + def wrapper(*args, **kwargs): + ctx = args[0] start_func = time.perf_counter() + success = False try: - (results, success) = func(*args, **kwargs) - + result, success = func(*args, **kwargs) + except Exception as e: + raise UnhandledExit(e, message=traceback.format_exc()) + finally: fire_event( CommandCompleted( command=ctx.command_path, @@ -74,20 +108,15 @@ def wrapper(*args, **kwargs): elapsed=time.perf_counter() - start_func, ) ) - # Bare except because we really do want to catch ALL exceptions, - # i.e. we want to fire this event in ALL cases. - except: # noqa - fire_event( - CommandCompleted( - command=ctx.command_path, - success=False, - completed_at=get_json_string_utcnow(), - elapsed=time.perf_counter() - start_func, - ) + + if not success: + raise HandledExit( + result=result, + success=success, + exit_code=ExitCodes.ModelError.value, ) - raise - return (results, success) + return (result, success) return update_wrapper(wrapper, func) diff --git a/tests/functional/cli/test_cli_exit_codes.py b/tests/functional/cli/test_cli_exit_codes.py new file mode 100644 index 00000000000..e3067e5e42e --- /dev/null +++ b/tests/functional/cli/test_cli_exit_codes.py @@ -0,0 +1,38 @@ +import pytest + +from dbt.cli.main import cli +from dbt.cli.requires import HandledExit + + +good_sql = """ +select 1 as fun +""" + +bad_sql = """ +someting bad +""" + + +class CliRunnerBase: + def run_cli(self): + ctx = cli.make_context(cli.name, ["run"]) + return cli.invoke(ctx) + + +class TestExitCodeZero(CliRunnerBase): + @pytest.fixture(scope="class") + def models(self): + return {"model_one.sql": good_sql} + + def test_no_exc_thrown(self, project): + self.run_cli() + + +class TestExitCodeOne(CliRunnerBase): + @pytest.fixture(scope="class") + def models(self): + return {"model_one.sql": bad_sql} + + def test_exc_thrown(self, project): + with pytest.raises(HandledExit): + self.run_cli() diff --git a/tests/functional/cli/test_error_handling.py b/tests/functional/cli/test_error_handling.py new file mode 100644 index 00000000000..26bafaa1c1a --- /dev/null +++ b/tests/functional/cli/test_error_handling.py @@ -0,0 +1,17 @@ +import pytest + +from dbt.tests.util import run_dbt + + +model_one_sql = """ +someting bad +""" + + +class TestHandledExit: + @pytest.fixture(scope="class") + def models(self): + return {"model_one.sql": model_one_sql} + + def test_failed_run_does_not_throw(self, project): + run_dbt(["run"], expect_pass=False)