diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 05fa356f6..618d9e944 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -154,11 +154,7 @@ def __init__( self.should_upload_compiled_sql = should_upload_compiled_sql self.openlineage_events_completes: list[RunEvent] = [] self.invocation_mode = invocation_mode - self.invoke_dbt: Callable[..., FullOutputSubprocessResult | dbtRunnerResult] - self.handle_exception: Callable[..., None] self._dbt_runner: dbtRunner | None = None - if self.invocation_mode: - self._set_invocation_methods() if kwargs.get("emit_datasets", True) and settings.enable_dataset_alias and AIRFLOW_VERSION >= Version("2.10"): from airflow.datasets import DatasetAlias @@ -187,14 +183,23 @@ def subprocess_hook(self) -> FullOutputSubprocessHook: """Returns hook for running the bash command.""" return FullOutputSubprocessHook() - def _set_invocation_methods(self) -> None: - """Sets the associated run and exception handling methods based on the invocation mode.""" + @property + def invoke_dbt(self) -> Callable[..., FullOutputSubprocessResult | dbtRunnerResult]: if self.invocation_mode == InvocationMode.SUBPROCESS: - self.invoke_dbt = self.run_subprocess - self.handle_exception = self.handle_exception_subprocess + return self.run_subprocess elif self.invocation_mode == InvocationMode.DBT_RUNNER: - self.invoke_dbt = self.run_dbt_runner - self.handle_exception = self.handle_exception_dbt_runner + return self.run_dbt_runner + else: + raise ValueError(f"Invalid invocation mode: {self.invocation_mode}") + + @property + def handle_exception(self) -> Callable[..., None]: + if self.invocation_mode == InvocationMode.SUBPROCESS: + return self.handle_exception_subprocess + elif self.invocation_mode == InvocationMode.DBT_RUNNER: + return self.handle_exception_dbt_runner + else: + raise ValueError(f"Invalid invocation mode: {self.invocation_mode}") def _discover_invocation_mode(self) -> None: """Discovers the invocation mode based on the availability of dbtRunner for import. If dbtRunner is available, it will @@ -209,7 +214,6 @@ def _discover_invocation_mode(self) -> None: else: self.invocation_mode = InvocationMode.DBT_RUNNER self.log.info("dbtRunner is available. Using dbtRunner for invoking dbt.") - self._set_invocation_methods() def handle_exception_subprocess(self, result: FullOutputSubprocessResult) -> None: if self.skip_exit_code is not None and result.exit_code == self.skip_exit_code: diff --git a/cosmos/operators/virtualenv.py b/cosmos/operators/virtualenv.py index ebc8bcf3b..ac3d11799 100644 --- a/cosmos/operators/virtualenv.py +++ b/cosmos/operators/virtualenv.py @@ -83,10 +83,10 @@ def __init__( self.log.error("Cosmos virtualenv operators require the `py_requirements` parameter") def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> FullOutputSubprocessResult: - self.log.info("Trying to run the command:\n %s\nFrom %s", command, cwd) if self._py_bin is not None: self.log.info(f"Using Python binary from virtualenv: {self._py_bin}") command[0] = str(Path(self._py_bin).parent / "dbt") + self.log.info("Trying to run the command:\n %s\nFrom %s", command, cwd) subprocess_result = self.subprocess_hook.run_command( command=command, env=env, diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index ed954dfdf..2de6ca1e3 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -175,11 +175,26 @@ def test_dbt_base_operator_set_invocation_methods(invocation_mode, invoke_dbt_me dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, task_id="my-task", project_dir="my/dir", invocation_mode=invocation_mode ) - dbt_base_operator._set_invocation_methods() assert dbt_base_operator.invoke_dbt.__name__ == invoke_dbt_method assert dbt_base_operator.handle_exception.__name__ == handle_exception_method +def test_dbt_base_operator_invalid_invocation_mode(): + """Tests that an invalid invocation_mode raises a ValueError when invoke_dbt or handle_exception is accessed.""" + operator = ConcreteDbtLocalBaseOperator( + profile_config=profile_config, + task_id="my-task", + project_dir="my/dir", + ) + operator.invocation_mode = "invalid_mode" + + with pytest.raises(ValueError, match="Invalid invocation mode: invalid_mode"): + _ = operator.invoke_dbt + + with pytest.raises(ValueError, match="Invalid invocation mode: invalid_mode"): + _ = operator.handle_exception + + @pytest.mark.parametrize( "can_import_dbt, invoke_dbt_method, handle_exception_method", [