From f133723b1b593189ed6344d518e3dcd172b9b3b4 Mon Sep 17 00:00:00 2001 From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com> Date: Fri, 18 Oct 2024 19:51:26 +0200 Subject: [PATCH] [fix] Improve logic for local deployment of PythonScript (#168) * add spark_python_task support in Workflow class and update example workflows * fix: tests and expected bundle * feat: add adjustment of python file in case of local for spark_python * fix: adjust logic for workspace * feat: add common_task_parameters support * fix: missed slash * fix: improve tests * feat: update poetry lock * Feat/workflow parameters (#1) * Add workflows parameters * feat: add JobsParameters import in sample_workflows * fix: adjust logic for local deployment in PythonScript * fix: update version number to 0.11.0a0 in pyproject.toml --- brickflow/codegen/databricks_bundle.py | 27 +++++++++++++++++++++---- tests/codegen/test_databricks_bundle.py | 11 ++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/brickflow/codegen/databricks_bundle.py b/brickflow/codegen/databricks_bundle.py index e677af6..42c02a8 100644 --- a/brickflow/codegen/databricks_bundle.py +++ b/brickflow/codegen/databricks_bundle.py @@ -51,6 +51,7 @@ Targets, Workspace, ) +from brickflow.cli.projects import MultiProjectManager, get_brickflow_root from brickflow.codegen import ( CodegenInterface, DatabricksDefaultClusterTagKeys, @@ -461,12 +462,30 @@ def adjust_file_path(self, file_path: str) -> str: ] ).replace("//", "/") - # Finds the start position of the project name in the given file path and calculates the cut position. - # - `file_path.find(self.project.name)`: Finds the start index of the project name in the file path. - # - `+ len(self.project.name) + 1`: Moves the start position to the character after the project name. 
+ multi_project_manager = MultiProjectManager( + config_file_name=str(get_brickflow_root()) + ) + bf_project = multi_project_manager.get_project(self.project.name) + + start_index_of_project_root = file_path.find( + bf_project.path_from_repo_root_to_project_root + ) + + if start_index_of_project_root < 0: + raise ValueError( + f"Error while adjusting file path. " + f"Project root not found in the file path: {file_path}." + ) + + # Finds the start position of the path_from_repo_root_to_project_root in the given file path + # and calculates the cut position. + # - `file_path.find(...)`: Finds the start index of the project root in the file path. + # - `+ len(...) + 1`: Moves the start position to the character after the project root. # - Adjusts the file path by appending the local bundle path to the cut file path. cut_file_path = file_path[ - file_path.find(self.project.name) + len(self.project.name) + 1 : + start_index_of_project_root + + len(bf_project.path_from_repo_root_to_project_root) + + 1 : ] file_path = ( bundle_files_local_path + file_path diff --git a/tests/codegen/test_databricks_bundle.py b/tests/codegen/test_databricks_bundle.py index 0192ed2..a734eff 100644 --- a/tests/codegen/test_databricks_bundle.py +++ b/tests/codegen/test_databricks_bundle.py @@ -82,6 +82,7 @@ class TestBundleCodegen(TestCase): BrickflowEnvVars.BRICKFLOW_PROJECT_TAGS.value: "tag1 = value1, tag2 =value2 ", # spaces will be trimmed }, ) + @patch("brickflow.codegen.databricks_bundle.MultiProjectManager") @patch("brickflow.engine.task.get_job_id", return_value=12345678901234.0) @patch("subprocess.check_output") @patch("brickflow.context.ctx.get_parameter") @@ -96,12 +97,16 @@ def test_generate_bundle_local( dbutils: Mock, sub_proc_mock: Mock, get_job_id_mock: Mock, + multi_project_manager_mock: Mock, ): dbutils.return_value = None sub_proc_mock.return_value = b"" bf_version_mock.return_value = "1.0.0" workspace_client = get_workspace_client_mock() get_job_id_mock.return_value = 12345678901234.0 
+ multi_project_manager_mock.return_value.get_project.return_value = MagicMock( + path_from_repo_root_to_project_root="test-project" + ) # get caller part breaks here with Project( "test-project", @@ -138,6 +143,7 @@ def test_generate_bundle_local( BrickflowEnvVars.BRICKFLOW_WORKFLOW_SUFFIX.value: "_suffix", }, ) + @patch("brickflow.codegen.databricks_bundle.MultiProjectManager") @patch("brickflow.engine.task.get_job_id", return_value=12345678901234.0) @patch("subprocess.check_output") @patch("brickflow.context.ctx.get_parameter") @@ -146,18 +152,23 @@ def test_generate_bundle_local( "brickflow.context.ctx.get_current_timestamp", MagicMock(return_value=1704067200000), ) + # @patch() def test_generate_bundle_local_prefix_suffix( self, bf_version_mock: Mock, dbutils: Mock, sub_proc_mock: Mock, get_job_id_mock: Mock, + multi_project_manager_mock: Mock, ): dbutils.return_value = None sub_proc_mock.return_value = b"" bf_version_mock.return_value = "1.0.0" workspace_client = get_workspace_client_mock() get_job_id_mock.return_value = 12345678901234.0 + multi_project_manager_mock.return_value.get_project.return_value = MagicMock( + path_from_repo_root_to_project_root="test-project" + ) # get caller part breaks here with Project( "test-project",