From 191c635b45cec87802606c8608e81920c7313794 Mon Sep 17 00:00:00 2001 From: kesompochy <95553894+kesompochy@users.noreply.github.com> Date: Tue, 29 Oct 2024 20:53:57 +0900 Subject: [PATCH] Fix DbtVirtualenvBaseOperator to use correct virtualenv Python path (#1252) This PR addresses an issue where the `DbtVirtualenvBaseOperator` was executing dbt commands using the system-wide Python path instead of the virtualenv path. The root cause was that the self reference in the `run_subprocess` method was bound to a different instance than the one created during initialization, likely due to Airflow's DAG pickling mechanism. To resolve this, we've refactored the `invoke_dbt` and `handle_exception` methods to be properties. This ensures that they dynamically reference the correct method of the current instance at runtime, rather than being bound to a potentially stale instance from initialization. ## Related Issue(s) fix https://github.com/astronomer/astronomer-cosmos/issues/1246 This may be related to issue https://github.com/astronomer/astronomer-cosmos/issues/958 in version 1.5.0 ## Breaking Change? No ## Checklist - [x] I have made corresponding changes to the documentation (if required) - [ ] I have added tests that prove my fix is effective or that my feature works ## Additional Notes I acknowledge that ideally, a test should be added to reproduce the original issue and verify the fix. However, I found it challenging to create an appropriate test, especially considering that this might require an integration test with Airflow to properly simulate Airflow behavior. If there are suggestions for how to effectively test this scenario, I would greatly appreciate the guidance. I sincerely apologize for introducing this bug in the first place with PR https://github.com/astronomer/astronomer-cosmos/pull/1200. I understand this has caused inconvenience, and I'm grateful for the opportunity to fix it. I kindly request a thorough review of these changes to ensure we've fully addressed the issue without introducing new problems. --- cosmos/operators/local.py | 26 +++++++++++++++----------- cosmos/operators/virtualenv.py | 2 +- tests/operators/test_local.py | 17 ++++++++++++++++- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 05fa356f6..618d9e944 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -154,11 +154,7 @@ def __init__( self.should_upload_compiled_sql = should_upload_compiled_sql self.openlineage_events_completes: list[RunEvent] = [] self.invocation_mode = invocation_mode - self.invoke_dbt: Callable[..., FullOutputSubprocessResult | dbtRunnerResult] - self.handle_exception: Callable[..., None] self._dbt_runner: dbtRunner | None = None - if self.invocation_mode: - self._set_invocation_methods() if kwargs.get("emit_datasets", True) and settings.enable_dataset_alias and AIRFLOW_VERSION >= Version("2.10"): from airflow.datasets import DatasetAlias @@ -187,14 +183,23 @@ def subprocess_hook(self) -> FullOutputSubprocessHook: """Returns hook for running the bash command.""" return FullOutputSubprocessHook() - def _set_invocation_methods(self) -> None: - """Sets the associated run and exception handling methods based on the invocation mode.""" + @property + def invoke_dbt(self) -> Callable[..., FullOutputSubprocessResult | dbtRunnerResult]: if self.invocation_mode == InvocationMode.SUBPROCESS: - self.invoke_dbt = self.run_subprocess - self.handle_exception = self.handle_exception_subprocess + return self.run_subprocess elif self.invocation_mode == InvocationMode.DBT_RUNNER: - self.invoke_dbt = self.run_dbt_runner - self.handle_exception = self.handle_exception_dbt_runner + return self.run_dbt_runner + else: + raise ValueError(f"Invalid invocation mode: {self.invocation_mode}") + + @property + def handle_exception(self) -> Callable[..., None]: + if self.invocation_mode == InvocationMode.SUBPROCESS: + return self.handle_exception_subprocess + elif self.invocation_mode == InvocationMode.DBT_RUNNER: + return self.handle_exception_dbt_runner + else: + raise ValueError(f"Invalid invocation mode: {self.invocation_mode}") def _discover_invocation_mode(self) -> None: """Discovers the invocation mode based on the availability of dbtRunner for import. If dbtRunner is available, it will @@ -209,7 +214,6 @@ def _discover_invocation_mode(self) -> None: else: self.invocation_mode = InvocationMode.DBT_RUNNER self.log.info("dbtRunner is available. Using dbtRunner for invoking dbt.") - self._set_invocation_methods() def handle_exception_subprocess(self, result: FullOutputSubprocessResult) -> None: if self.skip_exit_code is not None and result.exit_code == self.skip_exit_code: diff --git a/cosmos/operators/virtualenv.py b/cosmos/operators/virtualenv.py index ebc8bcf3b..ac3d11799 100644 --- a/cosmos/operators/virtualenv.py +++ b/cosmos/operators/virtualenv.py @@ -83,10 +83,10 @@ def __init__( self.log.error("Cosmos virtualenv operators require the `py_requirements` parameter") def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> FullOutputSubprocessResult: - self.log.info("Trying to run the command:\n %s\nFrom %s", command, cwd) if self._py_bin is not None: self.log.info(f"Using Python binary from virtualenv: {self._py_bin}") command[0] = str(Path(self._py_bin).parent / "dbt") + self.log.info("Trying to run the command:\n %s\nFrom %s", command, cwd) subprocess_result = self.subprocess_hook.run_command( command=command, env=env, diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index ed954dfdf..2de6ca1e3 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -175,11 +175,26 @@ def test_dbt_base_operator_set_invocation_methods(invocation_mode, invoke_dbt_me dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, task_id="my-task", project_dir="my/dir", invocation_mode=invocation_mode ) - dbt_base_operator._set_invocation_methods() assert dbt_base_operator.invoke_dbt.__name__ == invoke_dbt_method assert dbt_base_operator.handle_exception.__name__ == handle_exception_method +def test_dbt_base_operator_invalid_invocation_mode(): + """Tests that an invalid invocation_mode raises a ValueError when invoke_dbt or handle_exception is accessed.""" + operator = ConcreteDbtLocalBaseOperator( + profile_config=profile_config, + task_id="my-task", + project_dir="my/dir", + ) + operator.invocation_mode = "invalid_mode" + + with pytest.raises(ValueError, match="Invalid invocation mode: invalid_mode"): + _ = operator.invoke_dbt + + with pytest.raises(ValueError, match="Invalid invocation mode: invalid_mode"): + _ = operator.handle_exception + + @pytest.mark.parametrize( "can_import_dbt, invoke_dbt_method, handle_exception_method", [