Enhancement/refactor python submission (#452)
* refactor and move common logic to core
ChenyuLInx authored Sep 15, 2022
1 parent 60f47d5 commit 36bbe0d
Showing 3 changed files with 49 additions and 49 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Under the Hood-20220912-104517.yaml
@@ -0,0 +1,7 @@
+kind: Under the Hood
+body: Better interface for python submission
+time: 2022-09-12T10:45:17.226481-07:00
+custom:
+  Author: ChenyuLInx
+  Issue: "452"
+  PR: "452"
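
The "better interface" the changelog refers to is the PythonJobHelper base class that dbt-core now exposes and that the diff below imports. A minimal sketch of that contract, inferred from how this commit subclasses and calls it (the exact signatures in dbt-core may differ):

from typing import Any, Dict


class PythonJobHelper:
    """Assumed shape of the dbt.adapters.base.PythonJobHelper contract."""

    def __init__(self, parsed_model: Dict, credentials: Any) -> None:
        # Concrete helpers receive the parsed model and the adapter
        # credentials, as BaseDatabricksHelper does further down.
        self.parsed_model = parsed_model
        self.credentials = credentials

    def submit(self, compiled_code: str) -> Any:
        # Each submission method implements its own way of running the
        # compiled Python model.
        raise NotImplementedError
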
43 changes: 20 additions & 23 deletions dbt/adapters/spark/impl.py
@@ -1,7 +1,7 @@
import re
from concurrent.futures import Future
from dataclasses import dataclass
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union, Type
from typing_extensions import TypeAlias

import agate
@@ -10,14 +10,17 @@
import dbt
import dbt.exceptions

-from dbt.adapters.base import AdapterConfig
-from dbt.adapters.base.impl import catch_as_completed, log_code_execution
-from dbt.adapters.base.meta import available
+from dbt.adapters.base import AdapterConfig, PythonJobHelper
+from dbt.adapters.base.impl import catch_as_completed
from dbt.contracts.connection import AdapterResponse
from dbt.adapters.sql import SQLAdapter
from dbt.adapters.spark import SparkConnectionManager
from dbt.adapters.spark import SparkRelation
from dbt.adapters.spark import SparkColumn
-from dbt.adapters.spark.python_submissions import PYTHON_SUBMISSION_HELPERS
+from dbt.adapters.spark.python_submissions import (
+    DBNotebookPythonJobHelper,
+    DBCommandsApiPythonJobHelper,
+)
from dbt.adapters.base import BaseRelation
from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER
from dbt.events import AdapterLogger
@@ -369,26 +372,20 @@ def run_sql_for_tests(self, sql, fetch, conn):
finally:
conn.transaction_open = False

-    @available.parse_none
-    @log_code_execution
-    def submit_python_job(self, parsed_model: dict, compiled_code: str, timeout=None):
-        # TODO improve the typing here. N.B. Jinja returns a `jinja2.runtime.Undefined` instead
-        # of `None` which evaluates to True!
-
-        # TODO limit this function to run only when doing the materialization of python nodes
-        # assuming that for python job running over 1 day user would mannually overwrite this
-        submission_method = parsed_model["config"].get("submission_method", "commands")
-        if submission_method not in PYTHON_SUBMISSION_HELPERS:
-            raise NotImplementedError(
-                "Submission method {} is not supported".format(submission_method)
-            )
-        job_helper = PYTHON_SUBMISSION_HELPERS[submission_method](
-            parsed_model, self.connections.profile.credentials
-        )
-        job_helper.submit(compiled_code)
-        # we don't really get any useful information back from the job submission other than success
+    def generate_python_submission_response(self, submission_result: Any) -> AdapterResponse:
+        return self.connections.get_response(None)
+
+    @property
+    def default_python_submission_method(self) -> str:
+        return "commands"
+
+    @property
+    def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]:
+        return {
+            "notebook": DBNotebookPythonJobHelper,
+            "commands": DBCommandsApiPythonJobHelper,
+        }

def standardize_grants_dict(self, grants_table: agate.Table) -> dict:
grants_dict: Dict[str, List[str]] = {}
for row in grants_table:
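
The Spark adapter no longer implements submit_python_job itself; the dispatch logic deleted above now lives in dbt-core and consumes the three new hooks. A hedged sketch of how core plausibly uses them, reconstructed from the removed Spark code (the core method's actual name and plumbing are assumptions):

def submit_python_job(self, parsed_model: dict, compiled_code: str) -> AdapterResponse:
    # Pick the submission method from the model config, falling back to
    # the adapter's declared default ("commands" for Spark).
    method = parsed_model["config"].get(
        "submission_method", self.default_python_submission_method
    )
    if method not in self.python_submission_helpers:
        raise NotImplementedError(f"Submission method {method} is not supported")
    helper = self.python_submission_helpers[method](
        parsed_model, self.connections.profile.credentials
    )
    result = helper.submit(compiled_code)
    # Spark returns a generic response, since job submission reports
    # little useful information beyond success.
    return self.generate_python_submission_response(result)
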
48 changes: 22 additions & 26 deletions dbt/adapters/spark/python_submissions.py
@@ -5,14 +5,16 @@
import uuid

import dbt.exceptions
+from dbt.adapters.base import PythonJobHelper
+from dbt.adapters.spark import SparkCredentials

-DEFAULT_POLLING_INTERVAL = 3
+DEFAULT_POLLING_INTERVAL = 5
SUBMISSION_LANGUAGE = "python"
DEFAULT_TIMEOUT = 60 * 60 * 24


-class BasePythonJobHelper:
-    def __init__(self, parsed_model, credentials):
+class BaseDatabricksHelper(PythonJobHelper):
+    def __init__(self, parsed_model: Dict, credentials: SparkCredentials) -> None:
self.check_credentials(credentials)
self.credentials = credentials
self.identifier = parsed_model["alias"]
@@ -21,18 +23,18 @@ def __init__(self, parsed_model, credentials):
self.timeout = self.get_timeout()
self.polling_interval = DEFAULT_POLLING_INTERVAL

-    def get_timeout(self):
+    def get_timeout(self) -> int:
timeout = self.parsed_model["config"].get("timeout", DEFAULT_TIMEOUT)
if timeout <= 0:
raise ValueError("Timeout must be a positive integer")
return timeout

-    def check_credentials(self, credentials):
+    def check_credentials(self, credentials: SparkCredentials) -> None:
raise NotImplementedError(
"Overwrite this method to check specific requirement for current submission method"
)

-    def submit(self, compiled_code):
+    def submit(self, compiled_code: str) -> None:
raise NotImplementedError(
"BasePythonJobHelper is an abstract class and you should implement submit method."
)
@@ -45,7 +47,7 @@ def polling(
terminal_states,
expected_end_state,
get_state_msg_func,
-    ):
+    ) -> Dict:
state = None
start = time.time()
exceeded_timeout = False
@@ -54,7 +56,7 @@
if time.time() - start > self.timeout:
exceeded_timeout = True
break
-            # TODO should we do exponential backoff?
+            # should we do exponential backoff?
time.sleep(self.polling_interval)
response = status_func(**status_func_kwargs)
state = get_state_func(response)
@@ -68,16 +70,16 @@
return response
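
The polling helper is generic over any status API: it sleeps polling_interval seconds between calls to status_func until get_state_func reports a terminal state or self.timeout elapses. A hypothetical invocation for a Databricks notebook run (the endpoint and response fields mirror the Jobs 2.1 API used elsewhere in this file, but this particular call is illustrative, not part of the diff):

import requests

run_response = self.polling(
    status_func=requests.get,
    status_func_kwargs={
        "url": f"https://{self.credentials.host}/api/2.1/jobs/runs/get?run_id={run_id}",
        "headers": self.auth_header,
    },
    get_state_func=lambda response: response.json()["state"]["life_cycle_state"],
    terminal_states=("TERMINATED", "SKIPPED", "INTERNAL_ERROR"),
    expected_end_state="TERMINATED",
    get_state_msg_func=lambda response: response.json()["state"]["state_message"],
)
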


-class DBNotebookPythonJobHelper(BasePythonJobHelper):
-    def __init__(self, parsed_model, credentials):
+class DBNotebookPythonJobHelper(BaseDatabricksHelper):
+    def __init__(self, parsed_model: Dict, credentials: SparkCredentials) -> None:
super().__init__(parsed_model, credentials)
self.auth_header = {"Authorization": f"Bearer {self.credentials.token}"}

-    def check_credentials(self, credentials):
+    def check_credentials(self, credentials) -> None:
if not credentials.user:
raise ValueError("Databricks user is required for notebook submission method.")

-    def _create_work_dir(self, path):
+    def _create_work_dir(self, path: str) -> None:
response = requests.post(
f"https://{self.credentials.host}/api/2.0/workspace/mkdirs",
headers=self.auth_header,
@@ -90,7 +92,7 @@ def _create_work_dir(self, path):
f"Error creating work_dir for python notebooks\n {response.content!r}"
)

-    def _upload_notebook(self, path, compiled_code):
+    def _upload_notebook(self, path: str, compiled_code: str) -> None:
b64_encoded_content = base64.b64encode(compiled_code.encode()).decode()
response = requests.post(
f"https://{self.credentials.host}/api/2.0/workspace/import",
@@ -108,7 +110,7 @@ def _upload_notebook(self, path, compiled_code):
f"Error creating python notebook.\n {response.content!r}"
)

-    def _submit_notebook(self, path):
+    def _submit_notebook(self, path: str) -> str:
submit_response = requests.post(
f"https://{self.credentials.host}/api/2.1/jobs/runs/submit",
headers=self.auth_header,
@@ -126,7 +128,7 @@ def _submit_notebook(self, path):
)
return submit_response.json()["run_id"]

-    def submit(self, compiled_code):
+    def submit(self, compiled_code: str) -> None:
# it is safe to call mkdirs even if dir already exists and have content inside
work_dir = f"/Users/{self.credentials.user}/{self.schema}/"
self._create_work_dir(work_dir)
@@ -167,7 +169,7 @@ def submit(self, compiled_code):


class DBContext:
-    def __init__(self, credentials):
+    def __init__(self, credentials: SparkCredentials) -> None:
self.auth_header = {"Authorization": f"Bearer {credentials.token}"}
self.cluster = credentials.cluster
self.host = credentials.host
@@ -206,7 +208,7 @@ def destroy(self, context_id: str) -> str:


class DBCommand:
-    def __init__(self, credentials):
+    def __init__(self, credentials: SparkCredentials) -> None:
self.auth_header = {"Authorization": f"Bearer {credentials.token}"}
self.cluster = credentials.cluster
self.host = credentials.host
@@ -247,12 +249,12 @@ def status(self, context_id: str, command_id: str) -> Dict[str, Any]:
return response.json()


-class DBCommandsApiPythonJobHelper(BasePythonJobHelper):
-    def check_credentials(self, credentials):
+class DBCommandsApiPythonJobHelper(BaseDatabricksHelper):
+    def check_credentials(self, credentials: SparkCredentials) -> None:
if not credentials.cluster:
raise ValueError("Databricks cluster is required for commands submission method.")

-    def submit(self, compiled_code):
+    def submit(self, compiled_code: str) -> None:
context = DBContext(self.credentials)
command = DBCommand(self.credentials)
context_id = context.create()
@@ -276,9 +278,3 @@ def submit(self, compiled_code):
)
finally:
context.destroy(context_id)


-PYTHON_SUBMISSION_HELPERS = {
-    "notebook": DBNotebookPythonJobHelper,
-    "commands": DBCommandsApiPythonJobHelper,
-}
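
With the module-level PYTHON_SUBMISSION_HELPERS registry removed, adding a submission method now means writing a helper class and extending the adapter's python_submission_helpers property. A sketch under assumed names (UploadOnlyHelper and MySparkAdapter are hypothetical, made up for illustration):

from typing import Dict, Type

from dbt.adapters.base import PythonJobHelper
from dbt.adapters.spark.impl import SparkAdapter
from dbt.adapters.spark.python_submissions import BaseDatabricksHelper


class UploadOnlyHelper(BaseDatabricksHelper):
    """Hypothetical helper that only pretends to run the model."""

    def check_credentials(self, credentials) -> None:
        pass  # this toy method has no extra credential requirements

    def submit(self, compiled_code: str) -> None:
        print(f"would submit {len(compiled_code)} characters of Python")


class MySparkAdapter(SparkAdapter):
    """Hypothetical adapter subclass registering the extra method."""

    @property
    def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]:
        helpers = dict(super().python_submission_helpers)
        helpers["upload_only"] = UploadOnlyHelper
        return helpers
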
