Add db-connect
xiaohanzhan-db committed Dec 5, 2023
1 parent 23001cf commit 36f20ac
Showing 1 changed file with 89 additions and 17 deletions.
106 changes: 89 additions & 17 deletions scripts/data_prep/convert_delta_to_json.py
@@ -11,17 +11,89 @@

log = logging.getLogger(__name__)

def fetch_DT(*args: Any, **kwargs: Any):
    r"""Fetch a Delta table from UC and save it to local json files.

    This can be called as
    ```
    fetch_DT(server_hostname: str,
             access_token: str,
             tablename: str,
             json_output_path: str,
             batch_size: int = 1 << 20)
    or
    fetch_DT(server_hostname: str,
             access_token: str,
             http_path: str,
             tablename: str,
             json_output_path: str,
             batch_size: int = 1 << 20)
    ```
    Based on the arguments, the call is routed to either
    fetch_DT_with_dbconnect or fetch_DT_with_dbsql.
    """
    # http_path is the third positional argument of the dbsql variant, so a
    # call with more than four positional arguments (batch_size is assumed to
    # be passed by keyword) or an explicit http_path keyword goes to dbsql.
    if len(args) > 4 or kwargs.get('http_path') is not None:
        return fetch_DT_with_dbsql(*args, **kwargs)
    kwargs.pop('http_path', None)  # drop a None placeholder if present
    return fetch_DT_with_dbconnect(*args, **kwargs)
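
# Example routing (hypothetical values, a minimal sketch of the dispatch):
#   fetch_DT('myhost.cloud.databricks.com', 'dapiXXXX',
#            'main.schema.my_table', '/tmp/json_out')      # -> db-connect path
#   fetch_DT('myhost.cloud.databricks.com', 'dapiXXXX',
#            '/sql/1.0/warehouses/abc123',
#            'main.schema.my_table', '/tmp/json_out')      # -> dbsql path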

def fetch_DT_with_dbconnect(server_hostname: str,
                            access_token: str,
                            tablename: str,
                            json_output_path: str,
                            batch_size: int = 1 << 20):
    """Fetch a UC delta table with databricks-connect and convert it to json.

    When the table is very large, we fetch batch_size rows at a time.
    Compared to fetch_DT_with_dbsql, this function does not need http_path.
    """
    from uuid import uuid4

    from databricks.connect import DatabricksSession

    # Tag the Spark Connect session with a unique id; use the hostname and
    # token passed in by the caller rather than hardcoded values.
    session_id = str(uuid4())
    spark = DatabricksSession.builder.host(server_hostname).token(
        access_token).header('x-databricks-session-id',
                             session_id).getOrCreate()

    try:
        # Total row count determines how many batches we need.
        ans = spark.sql(f'SELECT COUNT(*) FROM {tablename}').collect()
        total_rows = [row.asDict() for row in ans][0].popitem()[1]

        # Use the table's first column as the ordering key for pagination.
        ans = spark.sql(f'SHOW COLUMNS IN {tablename}').collect()
        order_by = [row.asDict() for row in ans][0].popitem()[1]

        log.info(f'total_rows = {total_rows}')
        log.info(f'order by column {order_by}')
    except Exception as e:
        raise RuntimeError(
            f'Error in getting total rows / columns from {tablename}. '
            'Restart the Spark session and try again.') from e

    for start in range(0, total_rows, batch_size):
        end = min(start + batch_size, total_rows)

        # ROW_NUMBER() over a fixed ordering gives a stable 1-based row
        # index, so each batch selects the rows with rn in [start + 1, end].
        query = f"""
        WITH NumberedRows AS (
            SELECT
                *,
                ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn
            FROM
                {tablename}
        )
        SELECT *
        FROM NumberedRows
        WHERE rn BETWEEN {start+1} AND {end}"""

        # Write each batch directly to its own shard directory instead of
        # collecting rows to the driver and re-creating a dataframe.
        df = spark.sql(query)
        shard_path = os.path.join(json_output_path, f'shard_{start+1}_{end}')
        df.write.format('json').mode('overwrite').option(
            'header', 'true').save(shard_path)
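
# A minimal usage sketch (hypothetical host, token, and table names):
#   fetch_DT_with_dbconnect('myhost.cloud.databricks.com', 'dapiXXXX',
#                           'main.schema.my_table', '/tmp/json_out')
# With the default batch_size = 1 << 20 (1,048,576 rows), a table of exactly
# 3,000,000 rows would yield shard_1_1048576, shard_1048577_2097152 and
# shard_2097153_3000000 under /tmp/json_out, each a directory of JSON part
# files.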

def fetch_DT_with_dbsql(server_hostname: str,
                        access_token: str,
                        http_path: str,
                        tablename: str,
                        json_output_path: str,
                        batch_size: int = 1 << 20):
    """Fetch a UC delta table locally as dataframes and convert it to json.

    Save json files to local. When the table has more than batch_size rows,
    read the table batch_size rows at a time.
    """
    log.info('Start: convert delta table to json')

@@ -122,18 +194,18 @@ def stream_delta_to_json(server_hostname: str,
parser.add_argument('--debug', action='store_true', default=False)
args = parser.parse_args()

server_hostname = args.DATABRICKS_HOST if args.DATABRICKS_HOST is not None else os.getenv(
    'DATABRICKS_HOST')
access_token = args.DATABRICKS_TOKEN if args.DATABRICKS_TOKEN is not None else os.getenv(
    'DATABRICKS_TOKEN')
http_path = args.http_path
tablename = args.delta_table_name
json_output_path = args.json_output_path

tik = time.time()
print("start timer", tik)

fetch_DT(server_hostname=server_hostname,
         access_token=access_token,
         http_path=http_path,
         tablename=tablename,
         json_output_path=json_output_path)

print("end timer", time.time() - tik)
