-
Notifications
You must be signed in to change notification settings - Fork 28
/
nds_maintenance.py
319 lines (299 loc) · 14.5 KB
/
nds_maintenance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# -----
#
# Certain portions of the contents of this file are derived from TPC-DS version 3.2.0
# (retrieved from www.tpc.org/tpc_documents_current_versions/current_specifications5.asp).
# Such portions are subject to copyrights held by Transaction Processing Performance Council (“TPC”)
# and licensed under the TPC EULA (a copy of which accompanies this file as “TPC EULA” and is also
# available at http://www.tpc.org/tpc_documents_current_versions/current_specifications5.asp) (the “TPC EULA”).
#
# You may not use this file except in compliance with the TPC EULA.
# DISCLAIMER: Portions of this file is derived from the TPC-DS Benchmark and as such any results
# obtained using this file are not comparable to published TPC-DS Benchmark results, as the results
# obtained from using this file do not comply with the TPC-DS Benchmark.
#
import argparse
import csv
from datetime import datetime
import os
from pyspark.sql import SparkSession
from PysparkBenchReport import PysparkBenchReport
from check import check_json_summary_folder, get_abs_path
from nds_schema import get_maintenance_schemas
from nds_power import register_delta_tables
# Names of the insert (LF_*) Data Maintenance functions.
INSERT_FUNCS = ['LF_CR', 'LF_CS', 'LF_I', 'LF_SR', 'LF_SS', 'LF_WR', 'LF_WS']

# Names of the delete (DF_*) functions whose DATE1/DATE2 placeholders are filled
# from the "delete" table.
DELETE_FUNCS = ['DF_CS', 'DF_SS', 'DF_WS']

# Name of the delete function whose DATE1/DATE2 placeholders are filled from the
# "inventory_delete" table.
INVENTORY_DELETE_FUNC = ['DF_I']

# Every known Data Maintenance function: inserts first, then deletes, then the
# inventory delete.
DM_FUNCS = [*INSERT_FUNCS, *DELETE_FUNCS, *INVENTORY_DELETE_FUNC]
def get_delete_date(spark_session):
    """Collect the date ranges used by the Data Maintenance delete functions.

    Every row of the ``delete`` and ``inventory_delete`` tables is fetched through
    ``spark_session.sql`` and reduced to a ``(date1, date2)`` tuple.

    Args:
        spark_session (SparkSession): Spark session in which both tables are queryable.
    Returns:
        dict: ``{'delete': [(date1, date2), ...], 'inventory_delete': [(date1, date2), ...]}``
    """
    def _date_pairs(query):
        # One (date1, date2) tuple per row returned by the query.
        return [(row['date1'], row['date2']) for row in spark_session.sql(query).collect()]

    return {
        'delete': _date_pairs("select * from delete"),
        'inventory_delete': _date_pairs("select * from inventory_delete"),
    }
def replace_date(query_list, date_tuple_list):
    """Expand DELETE statements by substituting the DATE1/DATE2 placeholders.

    For every date tuple, each statement in ``query_list`` is emitted once with
    ``DATE1`` replaced by the earlier date and ``DATE2`` by the later one, so the
    result holds ``len(date_tuple_list) * len(query_list)`` statements, grouped by
    date tuple.

    Args:
        query_list ([str]): delete statements containing the placeholders
        date_tuple_list ([(str, str)]): date pairs formatted as ``YYYY-MM-DD``
    Returns:
        [str]: the statements with actual dates filled in
    """
    date_format = "%Y-%m-%d"
    expanded = []
    for pair in date_tuple_list:
        low, high = pair[0], pair[1]
        # The pair is not guaranteed to be ordered; swap only when the first
        # date is strictly later than the second.
        if datetime.strptime(low, date_format) > datetime.strptime(high, date_format):
            low, high = high, low
        expanded.extend(
            statement.replace("DATE1", low).replace("DATE2", high)
            for statement in query_list)
    return expanded
def get_valid_query_names(spec_queries):
    """Validate the requested Data Maintenance query names.

    When ``spec_queries`` is empty or None, the current module-level ``DM_FUNCS``
    is returned unchanged. Otherwise every requested name must be present in
    ``DM_FUNCS``; on success ``DM_FUNCS`` itself is rebound to ``spec_queries``.

    Args:
        spec_queries (list[str] | None): names requested by the user
    Returns:
        list[str]: the names of the queries to run
    Raises:
        Exception: if a requested name is not in ``DM_FUNCS``
    """
    global DM_FUNCS
    if not spec_queries:
        return DM_FUNCS
    for name in spec_queries:
        if name not in DM_FUNCS:
            raise Exception(f"invalid Data Maintenance query: {name}. Valid are: {DM_FUNCS}")
    # NOTE: this narrows the module-level list for the rest of the process.
    DM_FUNCS = spec_queries
    return DM_FUNCS
def create_spark_session(valid_queries, warehouse_path, warehouse_type):
    """Build (or fetch) the SparkSession used for the Data Maintenance run.

    The application name carries the query name when exactly one query is run.

    Args:
        valid_queries (list[str]): names of the queries that will be executed
        warehouse_path (str): warehouse location; only applied for "iceberg"
        warehouse_type (str): "delta" or "iceberg"
    Returns:
        SparkSession: result of ``getOrCreate()`` on the configured builder
    """
    app_name = "NDS - Data Maintenance"
    if len(valid_queries) == 1:
        app_name = "NDS - Data Maintenance - " + valid_queries[0]

    builder = SparkSession.builder
    if warehouse_type == "delta":
        # now we only support managed table(by Hive Metastore) for Data Maintenance
        builder.config("spark.sql.catalogImplementation", "hive")
    elif warehouse_type == "iceberg":
        builder.config("spark.sql.catalog.spark_catalog.warehouse", warehouse_path)
    return builder.appName(app_name).getOrCreate()
def get_maintenance_queries(spark_session, folder, valid_queries):
    """Load the SQL statements of each requested Data Maintenance query.

    Each ``<name>.sql`` file is split on ';'. The first piece (the license
    header) and the last piece (whatever follows the final ';') are dropped, and
    the ';' is re-appended to every remaining statement. Delete queries then get
    their DATE1/DATE2 placeholders expanded via ``replace_date``.

    Args:
        spark_session (SparkSession): session used to read the delete date tables
        folder (str): folder holding the Data Maintenance query files
        valid_queries (list[str]): names of the queries to load
    Returns:
        dict{str: list[str]}: query name mapped to its list of SQL statements
    """
    date_ranges = get_delete_date(spark_session)
    query_dir = get_abs_path(folder)
    statements_by_name = {}
    for name in valid_queries:
        with open(query_dir + '/' + name + '.sql', 'r') as sql_file:
            raw_text = sql_file.read()
        # file content e.g.
        # " LICENSE CONTENT ... ;"
        # " CREATE view ..... ; INSERT into .... ;"
        # " DELETE from ..... ; DELETE FROM .... ;"
        pieces = raw_text.split(';')
        statements = [piece + ';' for piece in pieces[1:-1]]
        if name in DELETE_FUNCS:
            # There're 3 date tuples to be replaced for one DELETE function
            # according to TPC-DS Spec 5.3.11
            statements = replace_date(statements, date_ranges['delete'])
        if name in INVENTORY_DELETE_FUNC:
            statements = replace_date(statements, date_ranges['inventory_delete'])
        statements_by_name[name] = statements
    return statements_by_name
def run_subquery_for_delta(spark_session, delete_query):
    """Rewrite a DELETE statement so it no longer contains a subquery.

    DeltaLake doesn't support DELETE with subquery, so the subquery is executed
    first and its result is spliced into the statement as literal values.
    See issue: https://github.com/delta-io/delta/issues/730
    Note this method is very tricky and is totally based on the query content itself:
    it locates the subquery purely by the positions of "(" and ")" and picks the
    min/max form by testing whether the text contains "min".
    TODO: remove this method when the issue above is resolved.

    Args:
        spark_session (SparkSession): session used to run the extracted subquery
        delete_query (str): one DELETE statement, optionally preceded by '--' comments
    Returns:
        str: a query that can be run on Delta Lake after subquery replacement
    Raises:
        Exception: if the expected parentheses cannot be found in the statement
    """
    # first strip out the license part
    delete_query = delete_query.split('--')[-1]

    # Validate the raw find() results *before* adding any offset; the original
    # compared find("(") + 1 against -1, which can never be true.
    open_paren = delete_query.find("(")
    first_close_paren = delete_query.find(")")
    if open_paren == -1 or first_close_paren == -1:
        raise Exception("invalid delete query")
    subquery_start = open_paren + 1

    if "min" not in delete_query:
        # e.g. "delete ... in (select ...);"
        subquery = delete_query[subquery_start:first_close_paren]
        subquery_df = spark_session.sql(subquery)
        # only 1 column, so retrieve directly at index 0
        col_name = subquery_df.schema.fields[0].name
        values = [row[col_name] for row in subquery_df.collect()]
        # form the string then drop "[" and "]"
        return delete_query.replace(subquery, str(values)[1:-1])

    # e.g. "delete ... (select min(d_date_sk) ... )... and ... ( select max(d_date_sk) ... );"
    # subquery_1 is between first "(" and second ")" -- the first ")" closes "min(...)".
    # subquery_2 is only different from subquery_1 in the "min" and "max" keyword.
    second_close_paren = delete_query.find(")", first_close_paren + 1)
    if second_close_paren == -1:
        raise Exception("invalid delete query")
    min_subquery = delete_query[subquery_start:second_close_paren]
    max_subquery = min_subquery.replace("min", "max")
    # result only 1 row.
    min_result = str(spark_session.sql(min_subquery).collect()[0][0])
    max_result = str(spark_session.sql(max_subquery).collect()[0][0])
    return delete_query.replace(min_subquery, min_result).replace(max_subquery, max_result)
def run_dm_query(spark, query_list, query_name, warehouse_type):
    """Execute every statement of one Data Maintenance query, in order.

    For delete queries, they can run on Spark 3.2.2 but not Spark 3.2.1
    See: https://issues.apache.org/jira/browse/SPARK-39454
    See: data_maintenance/DF_*.sql for delete query details.
    See data_maintenance/LF_*.sql for insert query details.

    Args:
        spark (SparkSession): SparkSession instance.
        query_list ([str]): SQL statements that make up the query.
        query_name (str): name of the Data Maintenance function being run.
        warehouse_type (str): when "delta", delete statements are first rewritten
            by ``run_subquery_for_delta``.
    """
    is_delete = query_name in DELETE_FUNCS + INVENTORY_DELETE_FUNC
    needs_rewrite = is_delete and warehouse_type == "delta"
    for statement in query_list:
        spark.sql(run_subquery_for_delta(spark, statement) if needs_rewrite else statement)
def run_query(spark_session,
              query_dict,
              time_log_output_path,
              json_summary_folder,
              property_file,
              warehouse_path,
              warehouse_type,
              keep_sc,
              delta_unmanaged=False):
    """Run all Data Maintenance queries and write the timing log.

    Each query is executed through ``PysparkBenchReport.report_on`` and its
    ``summary['queryTimes']`` is recorded. Start/end timestamps and elapsed
    seconds are appended afterwards, and everything is written as CSV to
    ``time_log_output_path``.

    Args:
        spark_session (SparkSession): active Spark session
        query_dict (dict{str: list[str]}): query name mapped to its SQL statements
        time_log_output_path (str): local path of the CSV time log to write
        json_summary_folder (str | None): when set, a JSON summary is written per query
        property_file (str | None): when set, its base name (up to the first '.')
            prefixes the JSON summary files
        warehouse_path (str): passed to ``register_delta_tables`` for unmanaged delta
        warehouse_type (str): "iceberg" or "delta"
        keep_sc (bool): when False, the SparkContext is stopped after the last query
        delta_unmanaged (bool): register delta tables from ``warehouse_path`` first
    """
    # TODO: Duplicate code in nds_power.py. Refactor this part, make it general.
    execution_time_list = []
    check_json_summary_folder(json_summary_folder)
    # Run query
    total_time_start = datetime.now()
    spark_app_id = spark_session.sparkContext.applicationId
    DM_start = datetime.now()
    if warehouse_type == 'delta' and delta_unmanaged:
        execution_time_list = register_delta_tables(spark_session, warehouse_path, execution_time_list)

    # The summary prefix does not depend on the query, so compute it once.
    summary_prefix = None
    if json_summary_folder:
        # property_file e.g.: "property/aqe-on.properties" or just "aqe-off.properties"
        if property_file:
            summary_prefix = os.path.join(
                json_summary_folder, os.path.basename(property_file).split('.')[0])
        else:
            summary_prefix = os.path.join(json_summary_folder, '')

    for query_name, q_content in query_dict.items():
        # show query name in Spark web UI
        spark_session.sparkContext.setJobGroup(query_name, query_name)
        print(f"====== Run {query_name} ======")
        q_report = PysparkBenchReport(spark_session, query_name)
        summary = q_report.report_on(run_dm_query, spark_session,
                                     q_content,
                                     query_name,
                                     warehouse_type)
        print(f"Time taken: {summary['queryTimes']} millis for {query_name}")
        execution_time_list.append((spark_app_id, query_name, summary['queryTimes']))
        if json_summary_folder:
            q_report.write_summary(prefix=summary_prefix)
    if not keep_sc:
        spark_session.sparkContext.stop()
    DM_end = datetime.now()
    DM_elapse = (DM_end - DM_start).total_seconds()
    total_elapse = (DM_end - total_time_start).total_seconds()
    print(f"====== Data Maintenance Start Time: {DM_start}")
    print(f"====== Data Maintenance Time: {DM_elapse} s ======")
    print(f"====== Total Time: {total_elapse} s ======")
    execution_time_list.extend([
        (spark_app_id, "Data Maintenance Start Time", DM_start),
        (spark_app_id, "Data Maintenance End Time", DM_end),
        (spark_app_id, "Data Maintenance Time", DM_elapse),
        (spark_app_id, "Total Time", total_elapse),
    ])

    # write to local csv file
    # NOTE(review): the header says "time/s" but per-query values are reported
    # above as millis; only the two elapsed rows are seconds -- confirm intent.
    header = ["application_id", "query", "time/s"]
    # newline='' is required by the csv module so the writer controls line
    # endings itself (otherwise extra blank rows appear on some platforms).
    with open(time_log_output_path, 'w', encoding='UTF8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(execution_time_list)
def register_temp_views(spark_session, refresh_data_path):
    """Expose the refresh data files as Spark temp views, one per table.

    For every table returned by ``get_maintenance_schemas(True)``, the
    '|'-delimited, header-less CSV data under ``<refresh_data_path>/<table>`` is
    read with the table's schema and registered as a temp view named after the
    table.

    Args:
        spark_session (SparkSession): session in which the views are created
        refresh_data_path (str): root folder containing one sub-path per table
    """
    for table, schema in get_maintenance_schemas(True).items():
        reader = spark_session.read.option("delimiter", '|').option("header", "false")
        refresh_df = reader.csv(refresh_data_path + '/' + table, schema=schema)
        refresh_df.createOrReplaceTempView(table)
# Script entry point: parse CLI arguments, build the Spark session, register the
# refresh data as temp views, load the maintenance queries and run them.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('warehouse_path',
                        help='warehouse path for Data Maintenance test.')
    parser.add_argument('refresh_data_path',
                        help='path to refresh data')
    parser.add_argument('maintenance_queries_folder',
                        help='folder contains all NDS Data Maintenance queries. If ' +
                        '"--maintenance_queries" is not set, all queries under the folder will be ' +
                        'executed.')
    # Required positional argument: argparse never applies a default here, so
    # none is declared.
    parser.add_argument('time_log',
                        help='path to execution time log, only support local path.')
    parser.add_argument('--maintenance_queries',
                        type=lambda s: s.split(','),
                        help='specify Data Maintenance query names by a comma separated string.' +
                        ' e.g. "LF_CR,LF_CS"')
    parser.add_argument('--property_file',
                        help='property file for Spark configuration.')
    parser.add_argument('--json_summary_folder',
                        help='Empty folder/path (will create if not exist) to save JSON summary file for each query.')
    parser.add_argument('--warehouse_type',
                        help='Type of the warehouse used for Data Maintenance test.',
                        choices=['iceberg', 'delta'],
                        default='iceberg')
    parser.add_argument('--keep_sc',
                        action='store_true',
                        help='Keep SparkContext alive after running all queries. This is a ' +
                        'limitation on Databricks runtime environment. User should always attach ' +
                        'this flag when running on Databricks.')
    parser.add_argument('--delta_unmanaged',
                        action='store_true',
                        help='Use unmanaged tables for DeltaLake. This is useful for testing DeltaLake without ' +
                        ' leveraging a Metastore service.')
    args = parser.parse_args()

    # Validate names before creating the session so a bad name fails fast.
    valid_queries = get_valid_query_names(args.maintenance_queries)
    spark_session = create_spark_session(valid_queries, args.warehouse_path, args.warehouse_type)
    # The temp views must exist before the queries are loaded, because loading
    # reads the "delete" and "inventory_delete" tables through this session.
    register_temp_views(spark_session, args.refresh_data_path)
    query_dict = get_maintenance_queries(spark_session,
                                         args.maintenance_queries_folder,
                                         valid_queries)
    run_query(spark_session, query_dict, args.time_log, args.json_summary_folder,
              args.property_file, args.warehouse_path, args.warehouse_type, args.keep_sc,
              args.delta_unmanaged)