From 36bbe0de5ec6069384c9a754ada85588fe032511 Mon Sep 17 00:00:00 2001
From: Chenyu Li
Date: Thu, 15 Sep 2022 08:20:42 -0700
Subject: [PATCH] Enhancement/refactor python submission (#452)

* refactor and move common logic to core
---
 .../Under the Hood-20220912-104517.yaml      |  7 +++
 dbt/adapters/spark/impl.py                   | 43 ++++++++---------
 dbt/adapters/spark/python_submissions.py     | 48 +++++++++----------
 3 files changed, 49 insertions(+), 49 deletions(-)
 create mode 100644 .changes/unreleased/Under the Hood-20220912-104517.yaml

diff --git a/.changes/unreleased/Under the Hood-20220912-104517.yaml b/.changes/unreleased/Under the Hood-20220912-104517.yaml
new file mode 100644
index 000000000..e45c97bf0
--- /dev/null
+++ b/.changes/unreleased/Under the Hood-20220912-104517.yaml
@@ -0,0 +1,7 @@
+kind: Under the Hood
+body: Better interface for python submission
+time: 2022-09-12T10:45:17.226481-07:00
+custom:
+  Author: ChenyuLInx
+  Issue: "452"
+  PR: "452"
diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py
index c228fc03d..77b1e4b5a 100644
--- a/dbt/adapters/spark/impl.py
+++ b/dbt/adapters/spark/impl.py
@@ -1,7 +1,7 @@
 import re
 from concurrent.futures import Future
 from dataclasses import dataclass
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union, Type
 from typing_extensions import TypeAlias
 
 import agate
@@ -10,14 +10,17 @@
 import dbt
 import dbt.exceptions
 
-from dbt.adapters.base import AdapterConfig
-from dbt.adapters.base.impl import catch_as_completed, log_code_execution
-from dbt.adapters.base.meta import available
+from dbt.adapters.base import AdapterConfig, PythonJobHelper
+from dbt.adapters.base.impl import catch_as_completed
+from dbt.contracts.connection import AdapterResponse
 from dbt.adapters.sql import SQLAdapter
 from dbt.adapters.spark import SparkConnectionManager
 from dbt.adapters.spark import SparkRelation
 from dbt.adapters.spark import SparkColumn
-from dbt.adapters.spark.python_submissions import PYTHON_SUBMISSION_HELPERS
+from dbt.adapters.spark.python_submissions import (
+    DBNotebookPythonJobHelper,
+    DBCommandsApiPythonJobHelper,
+)
 from dbt.adapters.base import BaseRelation
 from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER
 from dbt.events import AdapterLogger
@@ -369,26 +372,20 @@ def run_sql_for_tests(self, sql, fetch, conn):
         finally:
             conn.transaction_open = False
 
-    @available.parse_none
-    @log_code_execution
-    def submit_python_job(self, parsed_model: dict, compiled_code: str, timeout=None):
-        # TODO improve the typing here. N.B. Jinja returns a `jinja2.runtime.Undefined` instead
-        # of `None` which evaluates to True!
-
-        # TODO limit this function to run only when doing the materialization of python nodes
-        # assuming that for python job running over 1 day user would mannually overwrite this
-        submission_method = parsed_model["config"].get("submission_method", "commands")
-        if submission_method not in PYTHON_SUBMISSION_HELPERS:
-            raise NotImplementedError(
-                "Submission method {} is not supported".format(submission_method)
-            )
-        job_helper = PYTHON_SUBMISSION_HELPERS[submission_method](
-            parsed_model, self.connections.profile.credentials
-        )
-        job_helper.submit(compiled_code)
-        # we don't really get any useful information back from the job submission other than success
+    def generate_python_submission_response(self, submission_result: Any) -> AdapterResponse:
         return self.connections.get_response(None)
 
+    @property
+    def default_python_submission_method(self) -> str:
+        return "commands"
+
+    @property
+    def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]:
+        return {
+            "notebook": DBNotebookPythonJobHelper,
+            "commands": DBCommandsApiPythonJobHelper,
+        }
+
     def standardize_grants_dict(self, grants_table: agate.Table) -> dict:
         grants_dict: Dict[str, List[str]] = {}
         for row in grants_table:
diff --git a/dbt/adapters/spark/python_submissions.py b/dbt/adapters/spark/python_submissions.py
index ea172ef03..5ee4adb18 100644
--- a/dbt/adapters/spark/python_submissions.py
+++ b/dbt/adapters/spark/python_submissions.py
@@ -5,14 +5,16 @@
 import uuid
 
 import dbt.exceptions
+from dbt.adapters.base import PythonJobHelper
+from dbt.adapters.spark import SparkCredentials
 
-DEFAULT_POLLING_INTERVAL = 3
+DEFAULT_POLLING_INTERVAL = 5
 SUBMISSION_LANGUAGE = "python"
 DEFAULT_TIMEOUT = 60 * 60 * 24
 
 
-class BasePythonJobHelper:
-    def __init__(self, parsed_model, credentials):
+class BaseDatabricksHelper(PythonJobHelper):
+    def __init__(self, parsed_model: Dict, credentials: SparkCredentials) -> None:
         self.check_credentials(credentials)
         self.credentials = credentials
         self.identifier = parsed_model["alias"]
@@ -21,18 +23,18 @@ def __init__(self, parsed_model, credentials):
         self.timeout = self.get_timeout()
         self.polling_interval = DEFAULT_POLLING_INTERVAL
 
-    def get_timeout(self):
+    def get_timeout(self) -> int:
         timeout = self.parsed_model["config"].get("timeout", DEFAULT_TIMEOUT)
         if timeout <= 0:
             raise ValueError("Timeout must be a positive integer")
         return timeout
 
-    def check_credentials(self, credentials):
+    def check_credentials(self, credentials: SparkCredentials) -> None:
         raise NotImplementedError(
             "Overwrite this method to check specific requirement for current submission method"
         )
 
-    def submit(self, compiled_code):
+    def submit(self, compiled_code: str) -> None:
         raise NotImplementedError(
             "BasePythonJobHelper is an abstract class and you should implement submit method."
         )
@@ -45,7 +47,7 @@ def polling(
         terminal_states,
         expected_end_state,
         get_state_msg_func,
-    ):
+    ) -> Dict:
         state = None
         start = time.time()
         exceeded_timeout = False
@@ -54,7 +56,7 @@ def polling(
             if time.time() - start > self.timeout:
                 exceeded_timeout = True
                 break
-            # TODO should we do exponential backoff?
+            # should we do exponential backoff?
             time.sleep(self.polling_interval)
             response = status_func(**status_func_kwargs)
             state = get_state_func(response)
@@ -68,16 +70,16 @@ def polling(
         return response
 
 
-class DBNotebookPythonJobHelper(BasePythonJobHelper):
-    def __init__(self, parsed_model, credentials):
+class DBNotebookPythonJobHelper(BaseDatabricksHelper):
+    def __init__(self, parsed_model: Dict, credentials: SparkCredentials) -> None:
         super().__init__(parsed_model, credentials)
         self.auth_header = {"Authorization": f"Bearer {self.credentials.token}"}
 
-    def check_credentials(self, credentials):
+    def check_credentials(self, credentials) -> None:
         if not credentials.user:
             raise ValueError("Databricks user is required for notebook submission method.")
 
-    def _create_work_dir(self, path):
+    def _create_work_dir(self, path: str) -> None:
         response = requests.post(
             f"https://{self.credentials.host}/api/2.0/workspace/mkdirs",
             headers=self.auth_header,
@@ -90,7 +92,7 @@ def _create_work_dir(self, path):
                 f"Error creating work_dir for python notebooks\n {response.content!r}"
             )
 
-    def _upload_notebook(self, path, compiled_code):
+    def _upload_notebook(self, path: str, compiled_code: str) -> None:
         b64_encoded_content = base64.b64encode(compiled_code.encode()).decode()
         response = requests.post(
             f"https://{self.credentials.host}/api/2.0/workspace/import",
@@ -108,7 +110,7 @@ def _upload_notebook(self, path, compiled_code):
                 f"Error creating python notebook.\n {response.content!r}"
             )
 
-    def _submit_notebook(self, path):
+    def _submit_notebook(self, path: str) -> str:
         submit_response = requests.post(
             f"https://{self.credentials.host}/api/2.1/jobs/runs/submit",
             headers=self.auth_header,
@@ -126,7 +128,7 @@ def _submit_notebook(self, path):
             )
         return submit_response.json()["run_id"]
 
-    def submit(self, compiled_code):
+    def submit(self, compiled_code: str) -> None:
         # it is safe to call mkdirs even if dir already exists and have content inside
         work_dir = f"/Users/{self.credentials.user}/{self.schema}/"
         self._create_work_dir(work_dir)
@@ -167,7 +169,7 @@ def submit(self, compiled_code):
 
 
 class DBContext:
-    def __init__(self, credentials):
+    def __init__(self, credentials: SparkCredentials) -> None:
         self.auth_header = {"Authorization": f"Bearer {credentials.token}"}
         self.cluster = credentials.cluster
         self.host = credentials.host
@@ -206,7 +208,7 @@ def destroy(self, context_id: str) -> str:
 
 
 class DBCommand:
-    def __init__(self, credentials):
+    def __init__(self, credentials: SparkCredentials) -> None:
         self.auth_header = {"Authorization": f"Bearer {credentials.token}"}
         self.cluster = credentials.cluster
         self.host = credentials.host
@@ -247,12 +249,12 @@ def status(self, context_id: str, command_id: str) -> Dict[str, Any]:
         return response.json()
 
 
-class DBCommandsApiPythonJobHelper(BasePythonJobHelper):
-    def check_credentials(self, credentials):
+class DBCommandsApiPythonJobHelper(BaseDatabricksHelper):
+    def check_credentials(self, credentials: SparkCredentials) -> None:
         if not credentials.cluster:
             raise ValueError("Databricks cluster is required for commands submission method.")
 
-    def submit(self, compiled_code):
+    def submit(self, compiled_code: str) -> None:
         context = DBContext(self.credentials)
         command = DBCommand(self.credentials)
         context_id = context.create()
@@ -276,9 +278,3 @@ def submit(self, compiled_code):
             )
         finally:
             context.destroy(context_id)
-
-
-PYTHON_SUBMISSION_HELPERS = {
-    "notebook": DBNotebookPythonJobHelper,
-    "commands": DBCommandsApiPythonJobHelper,
-}
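
Note on usage: the helpers above still read their settings from the model config. `submission_method` selects an entry from `python_submission_helpers` (falling back to `default_python_submission_method`, i.e. "commands"), and `timeout` is validated by `BaseDatabricksHelper.get_timeout`. A minimal sketch of a dbt Python model exercising those configs is below; the model name and the returned columns are made up for illustration.

    # models/my_python_model.py -- illustrative only
    def model(dbt, session):
        # pick the notebook-based helper instead of the default "commands" API,
        # and let polling run for up to one hour before timing out
        dbt.config(
            materialized="table",
            submission_method="notebook",
            timeout=3600,
        )
        # `session` is a SparkSession; any Spark DataFrame can be returned
        return session.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])

Per `DBNotebookPythonJobHelper.check_credentials`, the notebook method additionally requires `user` in the Databricks profile, while the commands method requires `cluster`.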
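Note on extensibility: because the registry is now the `python_submission_helpers` property rather than the module-level `PYTHON_SUBMISSION_HELPERS` dict, an adapter subclass can register an extra submission method without touching this module, assuming dbt-core's shared `submit_python_job` looks the helper up by the model's `submission_method` config (that lookup lives in dbt-core and is not part of this diff). The class names and the "my_api" key below are hypothetical, used only to illustrate the shape of the interface.

    # Hypothetical sketch of plugging an extra helper into the new interface.
    from typing import Dict, Type

    from dbt.adapters.base import PythonJobHelper
    from dbt.adapters.spark.impl import SparkAdapter
    from dbt.adapters.spark.python_submissions import BaseDatabricksHelper


    class MyApiPythonJobHelper(BaseDatabricksHelper):  # hypothetical
        def check_credentials(self, credentials) -> None:
            # validate whatever this method needs from the profile
            if not credentials.cluster:
                raise ValueError("Databricks cluster is required for my_api submission method.")

        def submit(self, compiled_code: str) -> None:
            # run the compiled Python model through the custom API here
            raise NotImplementedError("left out of this sketch")


    class MyApiSparkAdapter(SparkAdapter):  # hypothetical
        @property
        def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]:
            helpers = dict(super().python_submission_helpers)
            helpers["my_api"] = MyApiPythonJobHelper
            return helpers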