Skip to content

Commit

Permalink
Fix DbtVirtualenvBaseOperator to use correct virtualenv Python path (#1252)
Browse files Browse the repository at this point in the history

This PR addresses an issue where the `DbtVirtualenvBaseOperator` was
executing dbt commands using the system-wide Python path instead of the
virtualenv path. The root cause was that the `self` reference in the
`run_subprocess` method was bound to a different instance than the one
created during initialization, likely due to Airflow's DAG pickling
mechanism.
To resolve this, we've replaced the `invoke_dbt` and `handle_exception`
instance attributes, which were assigned bound methods during
initialization, with properties. This ensures that they resolve to the
correct method of the current instance at access time, rather than
holding a method bound to a potentially stale instance from
initialization.

## Related Issue(s)

fix #1246

This may be related to issue #958 in version 1.5.0.

## Breaking Change?

No

## Checklist

- [x] I have made corresponding changes to the documentation (if
required)
- [ ] I have added tests that prove my fix is effective or that my
feature works

## Additional Notes

I acknowledge that ideally, a test should be added to reproduce the
original issue and verify the fix. However, I found it challenging to
create an appropriate test, especially considering that this might
require an integration test with Airflow to properly simulate Airflow
behavior. If there are suggestions for how to effectively test this
scenario, I would greatly appreciate the guidance.

I sincerely apologize for introducing this bug in the first place with
PR #1200. I understand this has caused inconvenience, and I'm grateful for the
opportunity to fix it. I kindly request a thorough review of these
changes to ensure we've fully addressed the issue without introducing
new problems.
  • Loading branch information
kesompochy authored Oct 29, 2024
1 parent 84c5fbd commit 191c635
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 13 deletions.
26 changes: 15 additions & 11 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,7 @@ def __init__(
self.should_upload_compiled_sql = should_upload_compiled_sql
self.openlineage_events_completes: list[RunEvent] = []
self.invocation_mode = invocation_mode
self.invoke_dbt: Callable[..., FullOutputSubprocessResult | dbtRunnerResult]
self.handle_exception: Callable[..., None]
self._dbt_runner: dbtRunner | None = None
if self.invocation_mode:
self._set_invocation_methods()

if kwargs.get("emit_datasets", True) and settings.enable_dataset_alias and AIRFLOW_VERSION >= Version("2.10"):
from airflow.datasets import DatasetAlias
Expand Down Expand Up @@ -187,14 +183,23 @@ def subprocess_hook(self) -> FullOutputSubprocessHook:
"""Returns hook for running the bash command."""
return FullOutputSubprocessHook()

def _set_invocation_methods(self) -> None:
"""Sets the associated run and exception handling methods based on the invocation mode."""
@property
def invoke_dbt(self) -> Callable[..., FullOutputSubprocessResult | dbtRunnerResult]:
if self.invocation_mode == InvocationMode.SUBPROCESS:
self.invoke_dbt = self.run_subprocess
self.handle_exception = self.handle_exception_subprocess
return self.run_subprocess
elif self.invocation_mode == InvocationMode.DBT_RUNNER:
self.invoke_dbt = self.run_dbt_runner
self.handle_exception = self.handle_exception_dbt_runner
return self.run_dbt_runner
else:
raise ValueError(f"Invalid invocation mode: {self.invocation_mode}")

@property
def handle_exception(self) -> Callable[..., None]:
if self.invocation_mode == InvocationMode.SUBPROCESS:
return self.handle_exception_subprocess
elif self.invocation_mode == InvocationMode.DBT_RUNNER:
return self.handle_exception_dbt_runner
else:
raise ValueError(f"Invalid invocation mode: {self.invocation_mode}")

def _discover_invocation_mode(self) -> None:
"""Discovers the invocation mode based on the availability of dbtRunner for import. If dbtRunner is available, it will
Expand All @@ -209,7 +214,6 @@ def _discover_invocation_mode(self) -> None:
else:
self.invocation_mode = InvocationMode.DBT_RUNNER
self.log.info("dbtRunner is available. Using dbtRunner for invoking dbt.")
self._set_invocation_methods()

def handle_exception_subprocess(self, result: FullOutputSubprocessResult) -> None:
if self.skip_exit_code is not None and result.exit_code == self.skip_exit_code:
Expand Down
2 changes: 1 addition & 1 deletion cosmos/operators/virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ def __init__(
self.log.error("Cosmos virtualenv operators require the `py_requirements` parameter")

def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> FullOutputSubprocessResult:
self.log.info("Trying to run the command:\n %s\nFrom %s", command, cwd)
if self._py_bin is not None:
self.log.info(f"Using Python binary from virtualenv: {self._py_bin}")
command[0] = str(Path(self._py_bin).parent / "dbt")
self.log.info("Trying to run the command:\n %s\nFrom %s", command, cwd)
subprocess_result = self.subprocess_hook.run_command(
command=command,
env=env,
Expand Down
17 changes: 16 additions & 1 deletion tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,26 @@ def test_dbt_base_operator_set_invocation_methods(invocation_mode, invoke_dbt_me
dbt_base_operator = ConcreteDbtLocalBaseOperator(
profile_config=profile_config, task_id="my-task", project_dir="my/dir", invocation_mode=invocation_mode
)
dbt_base_operator._set_invocation_methods()
assert dbt_base_operator.invoke_dbt.__name__ == invoke_dbt_method
assert dbt_base_operator.handle_exception.__name__ == handle_exception_method


def test_dbt_base_operator_invalid_invocation_mode():
"""Tests that an invalid invocation_mode raises a ValueError when invoke_dbt or handle_exception is accessed."""
operator = ConcreteDbtLocalBaseOperator(
profile_config=profile_config,
task_id="my-task",
project_dir="my/dir",
)
operator.invocation_mode = "invalid_mode"

with pytest.raises(ValueError, match="Invalid invocation mode: invalid_mode"):
_ = operator.invoke_dbt

with pytest.raises(ValueError, match="Invalid invocation mode: invalid_mode"):
_ = operator.handle_exception


@pytest.mark.parametrize(
"can_import_dbt, invoke_dbt_method, handle_exception_method",
[
Expand Down

0 comments on commit 191c635

Please sign in to comment.