poll .GetBatch() instead of using operation.result() (#929)
* re:PR https://github.com/dbt-labs/dbt-bigquery/pull/840/files

* adding back the `# check if job failed` comment

* adding changelog

* precommit code format

* sleep(2) first in the while loop, before the request, to eliminate the trailing 2-second sleep when the response is already in one of the 3 terminal states

* removing empty spaces

* update batch request to handle `GetBatchRequest`

* conditionally run python model tests and factor out batch functions to own module

* Move events to common

* fix import

* fix mistaken import change

* update unit test

* clean up and typing

---------

Co-authored-by: Zi Wang <[email protected]>
Co-authored-by: wazi55 <[email protected]>
Co-authored-by: Anders <[email protected]>
Co-authored-by: Mike Alfare <[email protected]>
5 people authored Sep 26, 2023
1 parent 63ae274 commit e5a89af
Showing 7 changed files with 123 additions and 38 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230721-101041.yaml
@@ -0,0 +1,6 @@
kind: Fixes
body: Serverless Spark to Poll with .GetBatch() instead of using operation.result()
time: 2023-07-21T10:10:41.64843-07:00
custom:
Author: wazi55
Issue: "734"
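For orientation, here is a minimal sketch of the polling change this fix describes: instead of blocking on `operation.result()`, the batch is looked up by name via `GetBatch` until it leaves a running state or a timeout elapses. This is condensed from the new `dbt/adapters/bigquery/dataproc/batch.py` shown later in the diff; the function name is hypothetical and error handling is simplified, so it is not the verbatim implementation.

import time

from google.cloud.dataproc_v1 import Batch, BatchControllerClient, GetBatchRequest


def wait_for_batch(job_client: BatchControllerClient, batch_name: str, timeout: int) -> Batch:
    # Poll the batch state directly via GetBatch instead of waiting on operation.result().
    response = None
    state = Batch.State.PENDING
    start = time.monotonic()
    while state in (Batch.State.PENDING, Batch.State.RUNNING):
        if time.monotonic() - start >= timeout:
            raise ValueError(f"Operation did not complete within {timeout} seconds.")
        # sleep first so a batch that is already finished is returned without an extra wait
        time.sleep(1)
        response = job_client.get_batch(request=GetBatchRequest(name=batch_name))
        state = response.state
    return response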
21 changes: 21 additions & 0 deletions .github/workflows/integration.yml
@@ -64,6 +64,7 @@ jobs:

    outputs:
      matrix: ${{ steps.generate-matrix.outputs.result }}
      run-python-tests: ${{ steps.filter.outputs.bigquery-python }}

    steps:
      - name: Check out the repository (non-PR)
@@ -96,6 +97,11 @@ jobs:
              - 'dbt/**'
              - 'tests/**'
              - 'dev-requirements.txt'
            bigquery-python:
              - 'dbt/adapters/bigquery/dataproc/**'
              - 'dbt/adapters/bigquery/python_submissions.py'
              - 'dbt/include/bigquery/python_model/**'
      - name: Generate integration test matrix
        id: generate-matrix
        uses: actions/github-script@v6
@@ -186,6 +192,21 @@ jobs:
          GCS_BUCKET: dbt-ci
        run: tox -- --ddtrace

      # python models tests are slow so we only want to run them if we're changing them
      - name: Run tox (python models)
        if: needs.test-metadata.outputs.run-python-tests == 'true'
        env:
          BIGQUERY_TEST_SERVICE_ACCOUNT_JSON: ${{ secrets.BIGQUERY_TEST_SERVICE_ACCOUNT_JSON }}
          BIGQUERY_TEST_ALT_DATABASE: ${{ secrets.BIGQUERY_TEST_ALT_DATABASE }}
          BIGQUERY_TEST_NO_ACCESS_DATABASE: ${{ secrets.BIGQUERY_TEST_NO_ACCESS_DATABASE }}
          DBT_TEST_USER_1: group:[email protected]
          DBT_TEST_USER_2: group:[email protected]
          DBT_TEST_USER_3: serviceAccount:[email protected]
          DATAPROC_REGION: us-central1
          DATAPROC_CLUSTER_NAME: dbt-test-1
          GCS_BUCKET: dbt-ci
        run: tox -e python-tests -- --ddtrace

      - uses: actions/upload-artifact@v3
        if: always()
        with:
Empty file.
67 changes: 67 additions & 0 deletions dbt/adapters/bigquery/dataproc/batch.py
@@ -0,0 +1,67 @@
from typing import Union, Dict

import time
from datetime import datetime
from google.cloud.dataproc_v1 import (
    CreateBatchRequest,
    BatchControllerClient,
    Batch,
    GetBatchRequest,
)
from google.protobuf.json_format import ParseDict

from dbt.adapters.bigquery.connections import DataprocBatchConfig

_BATCH_RUNNING_STATES = [Batch.State.PENDING, Batch.State.RUNNING]
DEFAULT_JAR_FILE_URI = "gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.12-0.21.1.jar"


def create_batch_request(
    batch: Batch, batch_id: str, project: str, region: str
) -> CreateBatchRequest:
    return CreateBatchRequest(
        parent=f"projects/{project}/locations/{region}",  # type: ignore
        batch_id=batch_id,  # type: ignore
        batch=batch,  # type: ignore
    )


def poll_batch_job(
    parent: str, batch_id: str, job_client: BatchControllerClient, timeout: int
) -> Batch:
    batch_name = "".join([parent, "/batches/", batch_id])
    state = Batch.State.PENDING
    response = None
    run_time = 0
    while state in _BATCH_RUNNING_STATES and run_time < timeout:
        time.sleep(1)
        response = job_client.get_batch(  # type: ignore
            request=GetBatchRequest(name=batch_name),  # type: ignore
        )
        run_time = datetime.now().timestamp() - response.create_time.timestamp()  # type: ignore
        state = response.state
    if not response:
        raise ValueError("No response from Dataproc")
    if state != Batch.State.SUCCEEDED:
        if run_time >= timeout:
            raise ValueError(
                f"Operation did not complete within the designated timeout of {timeout} seconds."
            )
        else:
            raise ValueError(response.state_message)
    return response


def update_batch_from_config(config_dict: Union[Dict, DataprocBatchConfig], target: Batch):
    try:
        # updates in place
        ParseDict(config_dict, target._pb)
    except Exception as e:
        docurl = (
            "https://cloud.google.com/dataproc-serverless/docs/reference/rpc/google.cloud.dataproc.v1"
            "#google.cloud.dataproc.v1.Batch"
        )
        raise ValueError(
            f"Unable to parse dataproc_batch as valid batch specification. See {docurl}. {str(e)}"
        ) from e
    return target
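A brief usage sketch for `update_batch_from_config`: the `dataproc_batch` mapping from the profile is merged into the `Batch` message in place via protobuf's `ParseDict`. The configuration values below are hypothetical and only illustrate the shape of the input.

from google.cloud import dataproc_v1

from dbt.adapters.bigquery.dataproc.batch import update_batch_from_config

# Hypothetical profile-level dataproc_batch configuration.
raw_batch_config = {
    "environment_config": {"execution_config": {"service_account": "[email protected]"}},
    "labels": {"role": "dev"},
}

batch = dataproc_v1.Batch()
batch = update_batch_from_config(raw_batch_config, batch)
assert batch.environment_config.execution_config.service_account == "[email protected]"
assert batch.labels["role"] == "dev"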
61 changes: 26 additions & 35 deletions dbt/adapters/bigquery/python_submissions.py
@@ -4,11 +4,17 @@
from google.api_core.future.polling import POLLING_PREDICATE

from dbt.adapters.bigquery import BigQueryConnectionManager, BigQueryCredentials
from dbt.adapters.bigquery.connections import DataprocBatchConfig
from google.api_core import retry
from google.api_core.client_options import ClientOptions
from google.cloud import storage, dataproc_v1  # type: ignore
from google.protobuf.json_format import ParseDict
from google.cloud.dataproc_v1.types.batches import Batch

from dbt.adapters.bigquery.dataproc.batch import (
    create_batch_request,
    poll_batch_job,
    DEFAULT_JAR_FILE_URI,
    update_batch_from_config,
)

OPERATION_RETRY_TIME = 10

@@ -102,8 +108,8 @@ def _submit_dataproc_job(self) -> dataproc_v1.types.jobs.Job:
                "job": job,
            }
        )
        response = operation.result(polling=self.result_polling_policy)
        # check if job failed
        response = operation.result(polling=self.result_polling_policy)
        if response.status.state == 6:
            raise ValueError(response.status.details)
        return response
@@ -118,21 +124,22 @@ def _get_job_client(self) -> dataproc_v1.BatchControllerClient:
    def _get_batch_id(self) -> str:
        return self.parsed_model["config"].get("batch_id")

    def _submit_dataproc_job(self) -> dataproc_v1.types.jobs.Job:
        batch = self._configure_batch()
        parent = f"projects/{self.credential.execution_project}/locations/{self.credential.dataproc_region}"

        request = dataproc_v1.CreateBatchRequest(
            parent=parent,
            batch=batch,
            batch_id=self._get_batch_id(),
        )
    def _submit_dataproc_job(self) -> Batch:
        batch_id = self._get_batch_id()
        request = create_batch_request(
            batch=self._configure_batch(),
            batch_id=batch_id,
            region=self.credential.dataproc_region,  # type: ignore
            project=self.credential.execution_project,  # type: ignore
        )  # type: ignore
        # make the request
        operation = self.job_client.create_batch(request=request)  # type: ignore
        # this takes quite a while, waiting on GCP response to resolve
        # (not a google-api-core issue, more likely a dataproc serverless issue)
        response = operation.result(polling=self.result_polling_policy)
        return response
        self.job_client.create_batch(request=request)  # type: ignore
        return poll_batch_job(
            parent=request.parent,
            batch_id=batch_id,
            job_client=self.job_client,  # type: ignore
            timeout=self.timeout,
        )
        # there might be useful results here that we can parse and return
        # Dataproc job output is saved to the Cloud Storage bucket
        # allocated to the job. Use regex to obtain the bucket and blob info.
@@ -163,27 +170,11 @@ def _configure_batch(self):
        batch.pyspark_batch.main_python_file_uri = self.gcs_location
        jar_file_uri = self.parsed_model["config"].get(
            "jar_file_uri",
            "gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.12-0.21.1.jar",
            DEFAULT_JAR_FILE_URI,
        )
        batch.pyspark_batch.jar_file_uris = [jar_file_uri]

        # Apply configuration from dataproc_batch key, possibly overriding defaults.
        if self.credential.dataproc_batch:
            self._update_batch_from_config(self.credential.dataproc_batch, batch)
            batch = update_batch_from_config(self.credential.dataproc_batch, batch)
        return batch

    @classmethod
    def _update_batch_from_config(
        cls, config_dict: Union[Dict, DataprocBatchConfig], target: dataproc_v1.Batch
    ):
        try:
            # updates in place
            ParseDict(config_dict, target._pb)
        except Exception as e:
            docurl = (
                "https://cloud.google.com/dataproc-serverless/docs/reference/rpc/google.cloud.dataproc.v1"
                "#google.cloud.dataproc.v1.Batch"
            )
            raise ValueError(
                f"Unable to parse dataproc_batch as valid batch specification. See {docurl}. {str(e)}"
            ) from e
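The comment above about retrieving Dataproc job output is left collapsed in this diff; as a hedged sketch of that pattern, assuming a cluster job response that exposes `driver_output_resource_uri` (as in the google-cloud-dataproc client) and the standard Cloud Storage client, the bucket and blob can be parsed out with a regex like this. The helper name and the `.000000000` output suffix are assumptions for illustration.

import re

from google.cloud import storage


def fetch_driver_output(response) -> str:
    # Dataproc saves driver output to the job's Cloud Storage bucket;
    # parse the bucket name and blob prefix out of the resource URI.
    matches = re.match("gs://(.*?)/(.*)", response.driver_output_resource_uri)
    bucket_name, blob_prefix = matches.group(1), matches.group(2)
    blob = storage.Client().get_bucket(bucket_name).blob(f"{blob_prefix}.000000000")
    return blob.download_as_bytes().decode("utf-8")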
2 changes: 1 addition & 1 deletion tests/functional/adapter/test_python_model.py
@@ -66,7 +66,7 @@ def model(dbt, spark):
"""

models__python_array_batch_id_python = """
import pandas
import pandas as pd
def model(dbt, spark):
    random_array = [
4 changes: 2 additions & 2 deletions tests/unit/test_configure_dataproc_batch.py
@@ -1,6 +1,6 @@
from unittest.mock import patch

from dbt.adapters.bigquery.python_submissions import ServerlessDataProcHelper
from dbt.adapters.bigquery.dataproc.batch import update_batch_from_config
from google.cloud import dataproc_v1

from .test_bigquery_adapter import BaseTestBigQueryAdapter
@@ -39,7 +39,7 @@ def test_update_dataproc_serverless_batch(self, mock_get_bigquery_defaults):

        batch = dataproc_v1.Batch()

        ServerlessDataProcHelper._update_batch_from_config(raw_batch_config, batch)
        batch = update_batch_from_config(raw_batch_config, batch)

        def to_str_values(d):
            """google's protobuf types expose maps as dict[str, str]"""
