From b69318e81b2addef170325edac9b627635033210 Mon Sep 17 00:00:00 2001
From: Nancy Hung
Date: Fri, 12 Jan 2024 23:06:28 -0800
Subject: [PATCH] Delta to JSONL conversion script cleanup and bug fix (#868)

* Small test change

* small cleanups

* lint and precommit

* lint and precommit

* comments

* another one

* pr suggestion and use input param not args
---
 scripts/data_prep/convert_delta_to_json.py | 110 +++++++++++++--------
 1 file changed, 70 insertions(+), 40 deletions(-)

diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py
index 029ce7f5c3..326b8e912f 100644
--- a/scripts/data_prep/convert_delta_to_json.py
+++ b/scripts/data_prep/convert_delta_to_json.py
@@ -33,8 +33,8 @@
 from pyspark.sql.dataframe import DataFrame as SparkDataFrame
 from pyspark.sql.types import Row
 
-MINIMUM_DB_CONNECT_DBR_VERSION = '14.1.0'
-MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2.0'
+MINIMUM_DB_CONNECT_DBR_VERSION = '14.1'
+MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2'
 
 log = logging.getLogger(__name__)
 
@@ -377,64 +377,61 @@ def fetch(
     cursor.close()
 
 
-def fetch_DT(args: Namespace) -> None:
-    """Fetch UC Delta Table to local as jsonl."""
-    log.info(f'Start .... Convert delta to json')
-
-    obj = urllib.parse.urlparse(args.json_output_folder)
-    if obj.scheme != '':
-        raise ValueError(
-            f'Check the json_output_folder and verify it is a local path!')
-
-    if os.path.exists(args.json_output_folder):
-        if not os.path.isdir(args.json_output_folder) or os.listdir(
-                args.json_output_folder):
-            raise RuntimeError(
-                f'A file or a folder {args.json_output_folder} already exists and is not empty. Remove it and retry!'
-            )
-
-    os.makedirs(args.json_output_folder, exist_ok=True)
-
-    if not args.json_output_filename.endswith('.jsonl'):
-        raise ValueError('json_output_filename needs to be a jsonl file')
-
-    log.info(f'Directory {args.json_output_folder} created.')
-
+def validate_and_get_cluster_info(cluster_id: str,
+                                  databricks_host: str,
+                                  databricks_token: str,
+                                  http_path: Optional[str],
+                                  use_serverless: bool = False) -> tuple:
+    """Validate and get cluster info for running the Delta to JSONL conversion.
+
+    Args:
+        cluster_id (str): cluster id to validate and fetch additional info for
+        databricks_host (str): databricks host name
+        databricks_token (str): databricks auth token
+        http_path (Optional[str]): http path to use for sql connect
+        use_serverless (bool): whether to use serverless or not
+    """
     method = 'dbsql'
     dbsql = None
     sparkSession = None
 
-    if args.use_serverless:
+    if use_serverless:
         method = 'dbconnect'
     else:
         w = WorkspaceClient()
-        res = w.clusters.get(cluster_id=args.cluster_id)
-        runtime_version = res.spark_version.split('-scala')[0].replace(
-            'x-snapshot', '0').replace('x', '0')
+        res = w.clusters.get(cluster_id=cluster_id)
+        if res is None:
+            raise ValueError(
+                f'Cluster id {cluster_id} does not exist. Check cluster id and try again!'
+            )
+        stripped_runtime = re.sub(
+            r'[a-zA-Z]', '',
+            res.spark_version.split('-scala')[0].replace('x-snapshot', ''))
+        runtime_version = re.sub(r'[.-]+$', '', stripped_runtime)
         if version.parse(runtime_version) < version.parse(
                 MINIMUM_SQ_CONNECT_DBR_VERSION):
             raise ValueError(
                 f'The minimum DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}'
             )
 
-        if args.http_path is None and version.parse(
+        if http_path is None and version.parse(
                 runtime_version) >= version.parse(
                     MINIMUM_DB_CONNECT_DBR_VERSION):
             method = 'dbconnect'
 
     if method == 'dbconnect':
         try:
-            if args.use_serverless:
+            if use_serverless:
                 session_id = str(uuid4())
                 sparkSession = DatabricksSession.builder.host(
-                    args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header(
+                    databricks_host).token(databricks_token).header(
                         'x-databricks-session-id', session_id).getOrCreate()
 
             else:
                 sparkSession = DatabricksSession.builder.remote(
-                    host=args.DATABRICKS_HOST,
-                    token=args.DATABRICKS_TOKEN,
-                    cluster_id=args.cluster_id).getOrCreate()
+                    host=databricks_host,
+                    token=databricks_token,
+                    cluster_id=cluster_id).getOrCreate()
 
         except Exception as e:
             raise RuntimeError(
@@ -444,15 +441,47 @@ def fetch_DT(args: Namespace) -> None:
     try:
         dbsql = sql.connect(
             server_hostname=re.compile(r'^https?://').sub(
-                '', args.DATABRICKS_HOST).strip(
+                '', databricks_host).strip(
                 ),  # sqlconnect hangs if hostname starts with https
-            http_path=args.http_path,
-            access_token=args.DATABRICKS_TOKEN,
+            http_path=http_path,
+            access_token=databricks_token,
         )
     except Exception as e:
         raise RuntimeError(
             'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!'
         ) from e
+    return method, dbsql, sparkSession
+
+
+def fetch_DT(args: Namespace) -> None:
+    """Fetch UC Delta Table to local as jsonl."""
+    log.info(f'Start .... Convert delta to json')
+
+    obj = urllib.parse.urlparse(args.json_output_folder)
+    if obj.scheme != '':
+        raise ValueError(
+            f'Check the json_output_folder and verify it is a local path!')
+
+    if os.path.exists(args.json_output_folder):
+        if not os.path.isdir(args.json_output_folder) or os.listdir(
+                args.json_output_folder):
+            raise RuntimeError(
+                f'A file or a folder {args.json_output_folder} already exists and is not empty. Remove it and retry!'
+            )
+
+    os.makedirs(args.json_output_folder, exist_ok=True)
+
+    if not args.json_output_filename.endswith('.jsonl'):
+        raise ValueError('json_output_filename needs to be a jsonl file')
+
+    log.info(f'Directory {args.json_output_folder} created.')
+
+    method, dbsql, sparkSession = validate_and_get_cluster_info(
+        cluster_id=args.cluster_id,
+        databricks_host=args.DATABRICKS_HOST,
+        databricks_token=args.DATABRICKS_TOKEN,
+        http_path=args.http_path,
+        use_serverless=args.use_serverless)
 
     fetch(method, args.delta_table_name, args.json_output_folder,
           args.batch_size, args.processes, sparkSession, dbsql)
@@ -494,9 +523,8 @@ def fetch_DT(args: Namespace) -> None:
                         help='number of processes allowed to use')
     parser.add_argument(
         '--cluster_id',
-        required=True,
+        required=False,
         type=str,
-        default=None,
         help=
         'cluster id with a runtime newer than 14.1.0 and access mode of either assigned or shared; such clusters can use databricks-connect.'
     )
@@ -513,7 +541,9 @@ def fetch_DT(args: Namespace) -> None:
         required=False,
         type=str,
         default='train-00000-of-00001.jsonl',
-        help='The combined final jsonl that combines all partitioned jsonl')
+        help=
+        'The name of the final jsonl file that combines all partitioned jsonl files'
+    )
 
     args = parser.parse_args()
     from databricks.sdk import WorkspaceClient
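
Note on the bug fix: the old code normalized Databricks runtime strings such as '14.1.x-scala2.12' by rewriting the 'x' wildcard to '0', which can break on runtimes that carry extra letter suffixes (e.g. ML runtimes); the new code instead strips all letters and any trailing '.' or '-' before comparing against the minimum DBR versions. Below is a minimal standalone sketch of the new normalization, using hypothetical spark_version values rather than ones fetched from a live workspace:

import re

from packaging import version

MINIMUM_DB_CONNECT_DBR_VERSION = '14.1'

for spark_version in ('14.1.x-scala2.12', '14.2.x-gpu-ml-scala2.12'):
    # Drop the '-scalaX.Y' suffix and the 'x-snapshot' placeholder, then
    # strip remaining letters (e.g. the 'x' wildcard or a 'gpu-ml' suffix).
    stripped_runtime = re.sub(
        r'[a-zA-Z]', '',
        spark_version.split('-scala')[0].replace('x-snapshot', ''))
    # Trim trailing '.' or '-' so version.parse() receives a valid version.
    runtime_version = re.sub(r'[.-]+$', '', stripped_runtime)
    ok = version.parse(runtime_version) >= version.parse(
        MINIMUM_DB_CONNECT_DBR_VERSION)
    print(f'{spark_version} -> {runtime_version} (db-connect eligible: {ok})')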