
Commit

update
xiaohanzhan-db committed Dec 5, 2023
1 parent a0c4249 commit 3951e83
Showing 1 changed file with 59 additions and 28 deletions.
87 changes: 59 additions & 28 deletions scripts/data_prep/convert_delta_to_json.py
@@ -6,13 +6,16 @@
import os
import time

import urllib.parse
import pandas as pd
from databricks import sql
from typing import Any, Optional, List
from typing import Any, Optional, List, Tuple
from databricks.connect import DatabricksSession
from uuid import uuid4
from pyspark.sql.types import Row
import concurrent.futures
from multiprocessing import Pool
import subprocess

log = logging.getLogger(__name__)

@@ -37,10 +40,38 @@ def run_query(q:str, method:str, cursor=None, spark=None, collect=True) -> Optio

return None
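Only the tail of run_query is visible in this hunk. For orientation, here is a minimal sketch of a body that would be consistent with how the function is called further down in the diff (a Spark DataFrame when collect=False under dbconnect, a list of rows when collect=True under dbsql); this is an assumption, not the committed implementation:

    # Assumed sketch only; the committed body of run_query is not shown in this diff.
    def run_query(q: str, method: str, cursor=None, spark=None, collect=True) -> Optional[Any]:
        if method == 'dbsql':
            cursor.execute(q)                 # databricks-sql-connector cursor
            return cursor.fetchall() if collect else None
        elif method == 'dbconnect':
            df = spark.sql(q)                 # Spark (Connect) DataFrame
            return df.collect() if collect else df
        return None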


def fetch_data_starargs(args: Tuple):
    return fetch_data(*args)

def fetch_data(method, cursor, sparkSession, s, e, order_by, tablename, columns_str, json_output_path):
    query = f"""
    WITH NumberedRows AS (
        SELECT
            *,
            ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn
        FROM
            {tablename}
    )
    SELECT {columns_str}
    FROM NumberedRows
    WHERE rn BETWEEN {s+1} AND {e}"""

    if method == 'dbconnect':
        pdf = run_query(query, method, cursor, sparkSession, collect=False).toPandas()
    elif method == 'dbsql':
        ans = run_query(query, method, cursor, sparkSession, collect=True)
        pdf = pd.DataFrame.from_dict([row.asDict() for row in ans])

    pdf.to_json(os.path.join(json_output_path,
                             f'part_{s+1}_{e}.json'))
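fetch_data_starargs simply unpacks one tuple of arguments into fetch_data, and the commit also imports multiprocessing.Pool without wiring it up yet, so the presumable intent is to fan batches out to worker processes. A rough sketch of that usage, built from the same variables fetch uses below (hypothetical; note that live cursors and Spark sessions generally cannot be pickled, so real workers would have to open their own connections):

    # Hypothetical usage sketch; not part of this commit.
    batches = [(method, cursor, sparkSession, s, min(s + batch_size, total_rows),
                order_by, tablename, columns_str, json_output_path)
               for s in range(0, total_rows, batch_size)]
    with Pool(processes=processes) as pool:
        pool.map(fetch_data_starargs, batches)   # each tuple is unpacked into fetch_data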


def fetch(method,
          tablename: str,
          json_output_path: str,
          batch_size: int = 1 << 20,
          processes = 1,
          sparkSession = None,
          dbsql = None,
          ):
@@ -66,34 +97,24 @@ def fetch(method,
except Exception as e:
raise RuntimeError(f"Error in get columns from {tablename}. Restart sparkSession and try again") from e

    def fetch_data(s, e, order_by, tablename, json_output_path):
        query = f"""
        WITH NumberedRows AS (
            SELECT
                *,
                ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn
            FROM
                {tablename}
        )
        SELECT {columns_str}
        FROM NumberedRows
        WHERE rn BETWEEN {start+1} AND {end}"""

        if method == 'dbconnect':
            pdf = run_query(query, method, cursor, sparkSession, collect=False).toPandas()
        elif method == 'dbsql':
            ans = run_query(query, method, cursor, sparkSession, collect=True)
            pdf = pd.DataFrame.from_dict([row.asDict() for row in ans])

        pdf.to_json(os.path.join(json_output_path,
                                 f'part_{s+1}_{e}.json'))

    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = []
    obj = urllib.parse.urlparse(json_output_path)

    if method == 'dbconnect':
        df = run_query(f"SELECT * FROM {tablename}", method, cursor, sparkSession, collect=False)
        print('processes = ', processes)

        dbfs_cache = 'dbfs:/' + json_output_path.lstrip('/')
        df.repartition(processes).write.mode("overwrite").json(dbfs_cache)
        print(f"downloading from {dbfs_cache} to {json_output_path}")
        subprocess.run(f"databricks fs cp -r {dbfs_cache} {json_output_path}", shell=True, capture_output=True, text=True)
        subprocess.run(f"databricks fs rm -r {dbfs_cache}", shell=True, capture_output=True, text=True)

    elif method == 'dbsql':
        ans = run_query(query, method, cursor, sparkSession, collect=True)
        pdf = pd.DataFrame.from_dict([row.asDict() for row in ans])
        for start in range(0, total_rows, batch_size):
            end = min(start + batch_size, total_rows)
            futures.append(executor.submit(fetch_data, start, end, order_by, tablename, json_output_path))

            fetch_data(method, cursor, sparkSession, start, end, order_by, tablename, columns_str, json_output_path)

    if cursor is not None:
        cursor.close()
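In the dbconnect branch above, the whole table is read as a Spark DataFrame, repartitioned into `processes` partitions, written as JSON to a DBFS staging location, copied down to json_output_path with the databricks CLI, and the staging copy is then removed. The two subprocess.run results are currently discarded, so a failed copy would pass silently; one way to surface such failures, using the same variables (illustrative only):

    # Illustrative only; not part of this commit.
    result = subprocess.run(f"databricks fs cp -r {dbfs_cache} {json_output_path}",
                            shell=True, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"databricks fs cp failed: {result.stderr}")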
@@ -106,6 +127,10 @@ def fetch_DT(*args: Any, **kwargs: Any):
    args = args[0]
    log.info(f'Start .... Convert delta to json')

    obj = urllib.parse.urlparse(args.json_output_path)
    if obj.scheme != '':
        raise ValueError(f"We don't support writing to remote yet in this script!")

    if os.path.exists(args.json_output_path):
        if not os.path.isdir(args.json_output_path) or os.listdir(
                args.json_output_path):
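The new urlparse guard above only accepts output paths without a URI scheme, i.e. local paths; remote destinations are rejected for now. For illustration (the example paths are made up):

    # Illustration; example paths are hypothetical.
    urllib.parse.urlparse('/tmp/json_out').scheme        # ''   -> accepted as a local path
    urllib.parse.urlparse('s3://bucket/json_out').scheme  # 's3' -> triggers the ValueError above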
@@ -138,7 +163,7 @@ def fetch_DT(*args: Any, **kwargs: Any):
        session_id = str(uuid4())
        sparkSession = DatabricksSession.builder.host(args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header("x-databricks-session-id", session_id).getOrCreate()

    fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, sparkSession, dbsql)
    fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, args.processes, sparkSession, dbsql)

    if dbsql is not None:
        dbsql.close()
@@ -175,6 +200,12 @@ def fetch_DT(*args: Any, **kwargs: Any):
                        default=1<<20,
                        help=
                        'chunk of rows to transmit at a time')
    parser.add_argument('--processes',
                        required=False,
                        type=int,
                        default=1,
                        help=
                        'number of processes to use')
    parser.add_argument('--debug', type=bool, required=False, default=False)
    args = parser.parse_args()
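With the new flag in place, an invocation might look like the command below. --processes and --debug are confirmed by this diff; the remaining flag names are assumed from the args. attributes used elsewhere in the script, the table and output path are placeholders, and the Databricks host/token arguments the script also reads are omitted:

    python scripts/data_prep/convert_delta_to_json.py \
        --delta_table_name main.my_schema.my_table \
        --json_output_path /tmp/my_table_json \
        --batch_size 1048576 \
        --processes 4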
